Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 22:42:25 +00:00)

test: migrate unit tests from unittest to pytest

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

parent ff9d4d8a9d
commit 9331253894

8 changed files with 1440 additions and 1405 deletions
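For orientation, here is a minimal standalone sketch of the translation this commit applies file by file. The class name, test body, and data are illustrative stand-ins, not lines from the diff, and it assumes pytest-asyncio is configured to collect module-level async test functions (e.g. asyncio_mode = "auto"); that configuration is an assumption, not something shown in this commit.

# Before: an async test method on unittest.IsolatedAsyncioTestCase
import unittest


class ExampleTests(unittest.IsolatedAsyncioTestCase):
    async def test_last_message_is_user_content(self):
        messages = ["system prompt", "Hello !"]  # stand-in data for illustration
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[-1], "Hello !")


# After: a plain module-level coroutine with bare assert statements
async def test_last_message_is_user_content():
    messages = ["system prompt", "Hello !"]  # stand-in data for illustration
    assert len(messages) == 2
    assert messages[-1] == "Hello !"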
@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import unittest

from llama_stack.apis.inference import (
    ChatCompletionRequest,
@@ -31,258 +29,259 @@ MODEL = "Llama3.1-8B-Instruct"
|
|||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
asyncio.get_running_loop().set_debug(False)
|
||||
async def test_system_default():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in messages[0].content
|
||||
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
||||
async def test_system_builtin_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
async def test_system_builtin_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in messages[0].content
|
||||
assert "Tools: brave_search" in messages[0].content
|
||||
|
||||
async def test_system_custom_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
async def test_system_custom_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
assert "Environment: ipython" in messages[0].content
|
||||
|
||||
async def test_system_custom_and_builtin(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
assert "Return function calls in JSON format" in messages[1].content
|
||||
assert messages[-1].content == content
|
||||
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_completion_message_encoding(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments={"param1": "value1"},
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn('[custom1(param1="value1")]', prompt)
|
||||
|
||||
request.model = MODEL
|
||||
request.tool_config.tool_prompt_format = ToolPromptFormat.json
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn(
|
||||
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
|
||||
prompt,
|
||||
)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_repalce_system_message_behavior_builtin_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
async def test_system_custom_and_builtin():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
|
||||
async def test_repalce_system_message_behavior_custom_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
assert "Environment: ipython" in messages[0].content
|
||||
assert "Tools: brave_search" in messages[0].content
|
||||
|
||||
assert "Return function calls in JSON format" in messages[1].content
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_completion_message_encoding():
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments={"param1": "value1"},
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '[custom1(param1="value1")]' in prompt
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertIn("You are a pirate", messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
self.assertIn('"name": "custom1"', messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
request.model = MODEL
|
||||
request.tool_config.tool_prompt_format = ToolPromptFormat.json
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
||||
|
||||
|
||||
async def test_user_provided_system_message():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content.endswith(system_prompt)
|
||||
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_repalce_system_message_behavior_builtin_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content.endswith(system_prompt)
|
||||
assert "Environment: ipython" in messages[0].content
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_repalce_system_message_behavior_custom_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content.endswith(system_prompt)
|
||||
assert "Environment: ipython" in messages[0].content
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "Environment: ipython" in messages[0].content
|
||||
assert "You are a pirate" in messages[0].content
|
||||
# function description is present in the system prompt
|
||||
assert '"name": "custom1"' in messages[0].content
|
||||
assert messages[-1].content == content
|
||||
|
@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
import unittest
from datetime import datetime

from llama_stack.models.llama.llama3.prompt_templates import (
@@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator):
|
||||
for example in generator.data_examples():
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
if not example:
|
||||
continue
|
||||
for tool in example:
|
||||
assert tool.tool_name in text
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
def check_generator_output(generator):
|
||||
for example in generator.data_examples():
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
assert "Overriding message." in text
|
||||
assert '"name": "get_weather"' in text
|
||||
if not example:
|
||||
continue
|
||||
for tool in example:
|
||||
assert tool.tool_name in text
|
||||
|
||||
|
||||
def test_system_default():
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
|
||||
def test_system_builtin_only():
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
|
||||
def test_system_custom_only():
|
||||
generator = JsonCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_system_custom_function_tag():
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_llama_3_2_system_zero_shot():
|
||||
generator = PythonListCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_llama_3_2_provided_system_prompt():
|
||||
generator = PythonListCustomToolGenerator()
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert "Overriding message." in text
|
||||
assert '"name": "get_weather"' in text
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
import unittest
from unittest.mock import patch

import pytest
@@ -15,124 +14,125 @@ from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIO
|
|||
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
||||
|
||||
|
||||
class TestNvidiaDatastore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
@pytest.fixture
|
||||
def nvidia_dataset_adapter():
|
||||
"""Set up the NVIDIA dataset adapter with mocked dependencies"""
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
self.adapter = NvidiaDatasetIOAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
adapter = NvidiaDatasetIOAdapter(config)
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
with patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
yield adapter, mock_make_request
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
def assert_request(mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper function to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_register_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
|
||||
async def test_register_dataset(nvidia_dataset_adapter):
|
||||
adapter, mock_make_request = nvidia_dataset_adapter
|
||||
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
await adapter.register_dataset(dataset_def)
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
self.run_async(self.adapter.register_dataset(dataset_def))
|
||||
async def test_unregister_dataset(nvidia_dataset_adapter):
|
||||
adapter, mock_make_request = nvidia_dataset_adapter
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
|
||||
def test_unregister_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
await adapter.unregister_dataset(dataset_id)
|
||||
|
||||
self.run_async(self.adapter.unregister_dataset(dataset_id))
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
|
||||
def test_register_dataset_with_custom_namespace_project(self):
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
async def test_register_dataset_with_custom_namespace_project(nvidia_dataset_adapter):
|
||||
adapter, mock_make_request = nvidia_dataset_adapter
|
||||
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "custom-namespace",
|
||||
}
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"format": "jsonl"},
|
||||
)
|
||||
|
||||
await custom_adapter.register_dataset(dataset_def)
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "custom-namespace",
|
||||
}
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"format": "jsonl"},
|
||||
)
|
||||
|
||||
self.run_async(custom_adapter.register_dataset(dataset_def))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "custom-namespace",
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "custom-project",
|
||||
"format": "jsonl",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "custom-project",
|
||||
"format": "jsonl",
|
||||
},
|
||||
)
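The datasetio hunk above replaces the setUp/tearDown patching with a generator fixture that yields the adapter together with the patched request mock. A generic, self-contained sketch of that pattern follows; the Client class, its request method, and the paths are invented for illustration and are not part of llama-stack.

from unittest.mock import patch

import pytest


class Client:
    def request(self, method: str, path: str) -> dict:
        raise NotImplementedError("network call; always mocked in tests")


@pytest.fixture
def client_with_mock():
    client = Client()
    # The with-block replaces tearDown(): the patch is undone automatically
    # when the generator resumes after the test finishes.
    with patch.object(Client, "request") as mock_request:
        yield client, mock_request


def test_request_is_issued(client_with_mock):
    client, mock_request = client_with_mock
    mock_request.return_value = {"ok": True}
    assert client.request("GET", "/v1/datasets") == {"ok": True}
    mock_request.assert_called_once_with("GET", "/v1/datasets")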
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest
@@ -21,181 +20,186 @@ MOCK_DATASET_ID = "default/test-dataset"
|
|||
MOCK_BENCHMARK_ID = "test-benchmark"
|
||||
|
||||
|
||||
class TestNVIDIAEvalImpl(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
|
||||
@pytest.fixture
|
||||
def nvidia_eval_impl():
|
||||
"""Set up the NVIDIA eval implementation with mocked dependencies"""
|
||||
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
|
||||
|
||||
# Create mock APIs
|
||||
self.datasetio_api = MagicMock()
|
||||
self.datasets_api = MagicMock()
|
||||
self.scoring_api = MagicMock()
|
||||
self.inference_api = MagicMock()
|
||||
self.agents_api = MagicMock()
|
||||
# Create mock APIs
|
||||
datasetio_api = MagicMock()
|
||||
datasets_api = MagicMock()
|
||||
scoring_api = MagicMock()
|
||||
inference_api = MagicMock()
|
||||
agents_api = MagicMock()
|
||||
|
||||
self.config = NVIDIAEvalConfig(
|
||||
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
|
||||
)
|
||||
config = NVIDIAEvalConfig(
|
||||
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
|
||||
)
|
||||
|
||||
self.eval_impl = NVIDIAEvalImpl(
|
||||
config=self.config,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
scoring_api=self.scoring_api,
|
||||
inference_api=self.inference_api,
|
||||
agents_api=self.agents_api,
|
||||
)
|
||||
eval_impl = NVIDIAEvalImpl(
|
||||
config=config,
|
||||
datasetio_api=datasetio_api,
|
||||
datasets_api=datasets_api,
|
||||
scoring_api=scoring_api,
|
||||
inference_api=inference_api,
|
||||
agents_api=agents_api,
|
||||
)
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.evaluator_get_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
|
||||
)
|
||||
self.evaluator_post_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
|
||||
)
|
||||
# Mock the HTTP request methods
|
||||
with (
|
||||
patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get,
|
||||
patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post,
|
||||
):
|
||||
yield eval_impl, mock_evaluator_get, mock_evaluator_post
|
||||
|
||||
self.mock_evaluator_get = self.evaluator_get_patcher.start()
|
||||
self.mock_evaluator_post = self.evaluator_post_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.evaluator_get_patcher.stop()
|
||||
self.evaluator_post_patcher.stop()
|
||||
def assert_request_body(mock_evaluator_post, expected_json):
|
||||
"""Helper function to verify request body in Evaluator POST request is correct"""
|
||||
call_args = mock_evaluator_post.call_args
|
||||
actual_json = call_args[0][1]
|
||||
|
||||
def _assert_request_body(self, expected_json):
|
||||
"""Helper method to verify request body in Evaluator POST request is correct"""
|
||||
call_args = self.mock_evaluator_post.call_args
|
||||
actual_json = call_args[0][1]
|
||||
# Check that all expected keys contain the expected values in the actual JSON
|
||||
for key, value in expected_json.items():
|
||||
assert key in actual_json, f"Key '{key}' missing in actual JSON"
|
||||
|
||||
# Check that all expected keys contain the expected values in the actual JSON
|
||||
for key, value in expected_json.items():
|
||||
assert key in actual_json, f"Key '{key}' missing in actual JSON"
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
|
||||
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
|
||||
else:
|
||||
assert actual_json[key] == value, f"Value mismatch for '{key}'"
|
||||
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
|
||||
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
|
||||
else:
|
||||
assert actual_json[key] == value, f"Value mismatch for '{key}'"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
async def test_register_benchmark(nvidia_eval_impl):
|
||||
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
|
||||
|
||||
def test_register_benchmark(self):
|
||||
eval_config = {
|
||||
"type": "custom",
|
||||
"params": {"parallelism": 8},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
|
||||
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
|
||||
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
benchmark = Benchmark(
|
||||
provider_id="nvidia",
|
||||
type="benchmark",
|
||||
identifier=MOCK_BENCHMARK_ID,
|
||||
dataset_id=MOCK_DATASET_ID,
|
||||
scoring_functions=["basic::equality"],
|
||||
metadata=eval_config,
|
||||
)
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Register the benchmark
|
||||
self.run_async(self.eval_impl.register_benchmark(benchmark))
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once()
|
||||
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
|
||||
|
||||
def test_run_eval(self):
|
||||
benchmark_config = BenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
type="model",
|
||||
model=CoreModelId.llama3_1_8b_instruct.value,
|
||||
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
|
||||
)
|
||||
)
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "created"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Run the Evaluation job
|
||||
result = self.run_async(
|
||||
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
|
||||
)
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once()
|
||||
self._assert_request_body(
|
||||
{
|
||||
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
|
||||
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
|
||||
eval_config = {
|
||||
"type": "custom",
|
||||
"params": {"parallelism": 8},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
|
||||
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
|
||||
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
benchmark = Benchmark(
|
||||
provider_id="nvidia",
|
||||
type="benchmark",
|
||||
identifier=MOCK_BENCHMARK_ID,
|
||||
dataset_id=MOCK_DATASET_ID,
|
||||
scoring_functions=["basic::equality"],
|
||||
metadata=eval_config,
|
||||
)
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
|
||||
mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Register the benchmark
|
||||
await eval_impl.register_benchmark(benchmark)
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
mock_evaluator_post.assert_called_once()
|
||||
assert_request_body(
|
||||
mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}
|
||||
)
|
||||
|
||||
|
||||
async def test_run_eval(nvidia_eval_impl):
|
||||
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
|
||||
|
||||
benchmark_config = BenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
type="model",
|
||||
model=CoreModelId.llama3_1_8b_instruct.value,
|
||||
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.in_progress
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "created"}
|
||||
mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
def test_job_status(self):
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "completed"}
|
||||
self.mock_evaluator_get.return_value = mock_evaluator_response
|
||||
# Run the Evaluation job
|
||||
result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
|
||||
|
||||
# Get the Evaluation job
|
||||
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
# Verify the Evaluator API was called correctly
|
||||
mock_evaluator_post.assert_called_once()
|
||||
assert_request_body(
|
||||
mock_evaluator_post,
|
||||
{
|
||||
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
|
||||
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.completed
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.in_progress
|
||||
|
||||
# Verify the API was called correctly
|
||||
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
|
||||
|
||||
def test_job_cancel(self):
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
async def test_job_status(nvidia_eval_impl):
|
||||
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
|
||||
|
||||
# Cancel the Evaluation job
|
||||
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "completed"}
|
||||
mock_evaluator_get.return_value = mock_evaluator_response
|
||||
|
||||
# Verify the API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
|
||||
# Get the Evaluation job
|
||||
result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
|
||||
|
||||
def test_job_result(self):
|
||||
# Mock Evaluator API responses
|
||||
mock_job_status_response = {"id": "job-123", "status": "completed"}
|
||||
mock_job_results_response = {
|
||||
"id": "job-123",
|
||||
"status": "completed",
|
||||
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
|
||||
}
|
||||
self.mock_evaluator_get.side_effect = [
|
||||
mock_job_status_response, # First call to retrieve job
|
||||
mock_job_results_response, # Second call to retrieve job results
|
||||
]
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.completed
|
||||
|
||||
# Get the Evaluation job results
|
||||
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
# Verify the API was called correctly
|
||||
mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EvaluateResponse)
|
||||
assert MOCK_BENCHMARK_ID in result.scores
|
||||
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
|
||||
|
||||
# Verify the API was called correctly
|
||||
assert self.mock_evaluator_get.call_count == 2
|
||||
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
|
||||
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
|
||||
async def test_job_cancel(nvidia_eval_impl):
|
||||
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
|
||||
mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Cancel the Evaluation job
|
||||
await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
|
||||
|
||||
# Verify the API was called correctly
|
||||
mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
|
||||
|
||||
|
||||
async def test_job_result(nvidia_eval_impl):
|
||||
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
|
||||
|
||||
# Mock Evaluator API responses
|
||||
mock_job_status_response = {"id": "job-123", "status": "completed"}
|
||||
mock_job_results_response = {
|
||||
"id": "job-123",
|
||||
"status": "completed",
|
||||
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
|
||||
}
|
||||
mock_evaluator_get.side_effect = [
|
||||
mock_job_status_response, # First call to retrieve job
|
||||
mock_job_results_response, # Second call to retrieve job results
|
||||
]
|
||||
|
||||
# Get the Evaluation job results
|
||||
result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EvaluateResponse)
|
||||
assert MOCK_BENCHMARK_ID in result.scores
|
||||
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
|
||||
|
||||
# Verify the API was called correctly
|
||||
assert mock_evaluator_get.call_count == 2
|
||||
mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
|
||||
mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
import unittest
import warnings
from unittest.mock import patch

@@ -27,253 +26,249 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
)
|
||||
|
||||
|
||||
class TestNvidiaParameters(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Set up the NVIDIA adapter with mock configuration"""
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
self.mock_make_request.return_value = {
|
||||
# Mock the _make_request method
|
||||
with patch(
|
||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
"id": "job-123",
|
||||
"status": "created",
|
||||
"created_at": "2025-03-04T13:07:47.543605",
|
||||
"updated_at": "2025-03-04T13:07:47.543605",
|
||||
}
|
||||
yield adapter, mock_make_request
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
|
||||
def _assert_request_params(self, expected_json):
|
||||
"""Helper method to verify parameters in the request JSON."""
|
||||
call_args = self.mock_make_request.call_args
|
||||
actual_json = call_args[1]["json"]
|
||||
def assert_request_params(mock_make_request, expected_json):
|
||||
"""Helper function to verify parameters in the request JSON."""
|
||||
call_args = mock_make_request.call_args
|
||||
actual_json = call_args[1]["json"]
|
||||
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert actual_json[key][nested_key] == nested_value
|
||||
else:
|
||||
assert actual_json[key] == value
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert actual_json[key][nested_key] == nested_value
|
||||
else:
|
||||
assert actual_json[key] == value
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def test_customizer_parameters_passed(self):
|
||||
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
async def test_customizer_parameters_passed(nvidia_adapter):
|
||||
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0002,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=3,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0002,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=3,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
fields = [
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
fields = [
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
self._assert_request_params(
|
||||
{
|
||||
"hyperparameters": {
|
||||
"lora": {"alpha": 16},
|
||||
"epochs": 3,
|
||||
"learning_rate": 0.0002,
|
||||
"batch_size": 16,
|
||||
}
|
||||
assert_request_params(
|
||||
mock_make_request,
|
||||
{
|
||||
"hyperparameters": {
|
||||
"lora": {"alpha": 16},
|
||||
"epochs": 3,
|
||||
"learning_rate": 0.0002,
|
||||
"batch_size": 16,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_required_parameters_passed(nvidia_adapter):
|
||||
"""Test scenario 2: When required parameters are passed."""
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
|
||||
required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
|
||||
required_dataset_id = "required-dataset"
|
||||
required_job_uuid = "required-job"
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid=required_job_uuid, # Required parameter
|
||||
model=required_model, # Required parameter
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
||||
def test_required_parameters_passed(self):
|
||||
"""Test scenario 2: When required parameters are passed."""
|
||||
required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
|
||||
required_dataset_id = "required-dataset"
|
||||
required_job_uuid = "required-job"
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
fields = [
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
call_args = mock_make_request.call_args
|
||||
|
||||
assert call_args[1]["json"]["config"] == required_model
|
||||
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
||||
|
||||
|
||||
async def test_unsupported_parameters_warning(nvidia_adapter):
|
||||
"""Test that warnings are raised for unsupported parameters."""
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="test-dataset",
|
||||
batch_size=8,
|
||||
# Unsupported parameters
|
||||
shuffle=True,
|
||||
data_format=DatasetFormat.instruct,
|
||||
validation_dataset_id="val-dataset",
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
# Unsupported parameters
|
||||
optimizer_type=OptimizerType.adam,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
efficiency_config = EfficiencyConfig(
|
||||
enable_activation_checkpointing=True # Unsupported parameter
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
# Unsupported parameters
|
||||
efficiency_config=efficiency_config,
|
||||
max_steps_per_epoch=1000,
|
||||
gradient_accumulation_steps=4,
|
||||
max_validation_steps=100,
|
||||
dtype="bf16",
|
||||
)
|
||||
|
||||
# Capture warnings
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="test-dir", # Unsupported parameter
|
||||
algorithm_config=LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
),
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={"test": "value"}, # Unsupported parameter
|
||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
assert len(w) >= 4
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid=required_job_uuid, # Required parameter
|
||||
model=required_model, # Required parameter
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
fields = [
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
call_args = self.mock_make_request.call_args
|
||||
|
||||
assert call_args[1]["json"]["config"] == required_model
|
||||
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
||||
|
||||
def test_unsupported_parameters_warning(self):
|
||||
"""Test that warnings are raised for unsupported parameters."""
|
||||
data_config = DataConfig(
|
||||
dataset_id="test-dataset",
|
||||
batch_size=8,
|
||||
# Unsupported parameters
|
||||
shuffle=True,
|
||||
data_format=DatasetFormat.instruct,
|
||||
validation_dataset_id="val-dataset",
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
# Unsupported parameters
|
||||
optimizer_type=OptimizerType.adam,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
efficiency_config = EfficiencyConfig(
|
||||
enable_activation_checkpointing=True # Unsupported parameter
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
# Unsupported parameters
|
||||
efficiency_config=efficiency_config,
|
||||
max_steps_per_epoch=1000,
|
||||
gradient_accumulation_steps=4,
|
||||
max_validation_steps=100,
|
||||
dtype="bf16",
|
||||
)
|
||||
|
||||
# Capture warnings
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="test-dir", # Unsupported parameter
|
||||
algorithm_config=LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
),
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={"test": "value"}, # Unsupported parameter
|
||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||
)
|
||||
)
|
||||
|
||||
assert len(w) >= 4
|
||||
warning_texts = [str(warning.message) for warning in w]
|
||||
|
||||
fields = [
|
||||
"checkpoint_dir",
|
||||
"hyperparam_search_config",
|
||||
"logger_config",
|
||||
"TrainingConfig",
|
||||
"DataConfig",
|
||||
"OptimizerConfig",
|
||||
"max_steps_per_epoch",
|
||||
"gradient_accumulation_steps",
|
||||
"max_validation_steps",
|
||||
"dtype",
|
||||
# unsupported parameters that come from the required LoRA algorithm config
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
fields = [
|
||||
"checkpoint_dir",
|
||||
"hyperparam_search_config",
|
||||
"logger_config",
|
||||
"TrainingConfig",
|
||||
"DataConfig",
|
||||
"OptimizerConfig",
|
||||
"max_steps_per_epoch",
|
||||
"gradient_accumulation_steps",
|
||||
"max_validation_steps",
|
||||
"dtype",
|
||||
# unsupported parameters that come from the required LoRA algorithm config
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
"apply_lora_to_mlp",
|
||||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
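As an aside, the conversion applied throughout this hunk replaces unittest.IsolatedAsyncioTestCase methods, self.assertEqual/assertTrue calls, and the run_async helper with module-level async test functions and bare asserts. A minimal, hypothetical sketch of that pattern (the Adder class and test names are illustrative, not from this repository; pytest-asyncio is assumed, and the explicit marker can be dropped when asyncio_mode is set to auto, as these tests appear to rely on):

import pytest


class Adder:
    """Toy async component standing in for an adapter under test."""

    async def add(self, a: int, b: int) -> int:
        if a < 0 or b < 0:
            raise ValueError("only non-negative inputs are supported")
        return a + b


@pytest.mark.asyncio  # not needed if pytest-asyncio runs in auto mode
async def test_add_returns_sum():
    adder = Adder()
    # pytest style: a bare assert replaces self.assertEqual(...)
    assert await adder.add(2, 3) == 5


@pytest.mark.asyncio
async def test_add_rejects_negative_input():
    adder = Adder()
    # pytest.raises replaces self.assertRaises(...)
    with pytest.raises(ValueError):
        await adder.add(-1, 3)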
@@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
|
@@ -18,308 +17,325 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
|||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
@pytest.fixture
|
||||
def nvidia_safety_adapter():
|
||||
"""Set up the NVIDIA safety adapter with mocked dependencies"""
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
# Initialize the adapter
|
||||
self.config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
self.adapter = NVIDIASafetyAdapter(config=self.config)
|
||||
self.shield_store = AsyncMock()
|
||||
self.adapter.shield_store = self.shield_store
|
||||
# Initialize the adapter
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
adapter = NVIDIASafetyAdapter(config=config)
|
||||
shield_store = AsyncMock()
|
||||
adapter.shield_store = shield_store
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.guardrails_post_patcher = patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
)
|
||||
self.mock_guardrails_post = self.guardrails_post_patcher.start()
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
# Mock the HTTP request methods
|
||||
with patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
) as mock_guardrails_post:
|
||||
mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
yield adapter, shield_store, mock_guardrails_post
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.guardrails_post_patcher.stop()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
def assert_request(
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to verify request details in mock API calls.
|
||||
|
||||
def _assert_request(
|
||||
self,
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_register_shield_with_valid_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
async def test_register_shield_with_valid_id(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
# Register the shield
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
def test_register_shield_without_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
# Register the shield
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
|
||||
def test_run_shield_allowed(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
async def test_register_shield_without_id(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Register the shield should raise a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
async def test_run_shield_allowed(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
|
||||
def test_run_shield_blocked(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
async def test_run_shield_blocked(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
shield_store.get_shield.return_value = shield
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
# Verify the result
|
||||
assert result.violation is not None
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation.user_message == "Sorry I cannot do this."
|
||||
assert result.violation.violation_level == ViolationLevel.ERROR
|
||||
assert result.violation.metadata == {"reason": "harmful_content"}
|
||||
|
||||
def test_run_shield_not_found(self):
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
self.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
async def test_run_shield_not_found(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
# Set up shield store to return None
|
||||
shield_id = "non-existent-shield"
|
||||
shield_store.get_shield.return_value = None
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
# Verify the Guardrails API was not called
|
||||
self.mock_guardrails_post.assert_not_called()
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
def test_run_shield_http_error(self):
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Verify the shield store was called
|
||||
shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
self.mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
# Verify the Guardrails API was not called
|
||||
mock_guardrails_post.assert_not_called()
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_run_shield_http_error(nvidia_safety_adapter):
|
||||
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API to raise an exception
|
||||
error_msg = "API Error: 500 Internal Server Error"
|
||||
mock_guardrails_post.side_effect = Exception(error_msg)
|
||||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(context.exception)
|
||||
},
|
||||
)
|
||||
# Verify the exception message
|
||||
assert error_msg in str(excinfo.value)
|
||||
|
||||
def test_init_nemo_guardrails(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
def test_init_nemo_guardrails():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
test_config_id = "test-custom-config-id"
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id=test_config_id,
|
||||
)
|
||||
# Initialize with default parameters
|
||||
test_model = "test-model"
|
||||
guardrails = NeMoGuardrails(config, test_model)
|
||||
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.9 # Default value
|
||||
assert guardrails.temperature == 1.0 # Default value
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature(self):
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
# Initialize with custom parameters
|
||||
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
|
||||
# Verify the attributes are set correctly
|
||||
assert guardrails.config_id == test_config_id
|
||||
assert guardrails.model == test_model
|
||||
assert guardrails.threshold == 0.8
|
||||
assert guardrails.temperature == 0.7
|
||||
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
|
||||
|
||||
|
||||
def test_init_nemo_guardrails_invalid_temperature():
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
|
||||
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
config_id="test-custom-config-id",
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
NeMoGuardrails(config, "test-model", temperature=0)
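A compact, hypothetical sketch of the setUp/tearDown-to-fixture conversion used in this file: the patcher start()/stop() pair becomes a with patch(...) block around a yield, and each test receives the adapter plus its mocks as a tuple (FakeAdapter and its _post helper are illustrative stand-ins, not real llama_stack APIs):

import os
from unittest.mock import AsyncMock, patch

import pytest


class FakeAdapter:
    """Illustrative stand-in for a safety adapter with an HTTP helper."""

    def __init__(self, url: str):
        self.url = url
        self.shield_store = None

    def _post(self, path: str, data: dict) -> dict:
        raise RuntimeError("real network calls are not allowed in unit tests")


@pytest.fixture
def adapter_with_mocks():
    # Everything before the yield is the old setUp ...
    os.environ["GUARDRAILS_URL"] = "http://nemo.test"
    adapter = FakeAdapter(os.environ["GUARDRAILS_URL"])
    adapter.shield_store = AsyncMock()
    # ... and the context-manager exit after the yield is the old tearDown.
    with patch.object(FakeAdapter, "_post", return_value={"status": "allowed"}) as mock_post:
        yield adapter, adapter.shield_store, mock_post


def test_adapter_uses_mocked_post(adapter_with_mocks):
    adapter, shield_store, mock_post = adapter_with_mocks
    assert adapter._post("/v1/guardrail/checks", data={}) == {"status": "allowed"}
    mock_post.assert_called_once()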
@@ -5,9 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@@ -32,331 +31,334 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
)
|
||||
|
||||
|
||||
class TestNvidiaPostTraining(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
@pytest.fixture
|
||||
def nvidia_adapters():
|
||||
"""Set up the NVIDIA adapters with mocked dependencies"""
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
# Mock the inference client
|
||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||
inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
|
||||
# Mock the inference client
|
||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||
|
||||
self.mock_client = unittest.mock.MagicMock()
|
||||
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
||||
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
||||
self.inference_make_request_patcher = patch(
|
||||
) as mock_make_request,
|
||||
patch(
|
||||
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
|
||||
return_value=self.mock_client,
|
||||
return_value=mock_client,
|
||||
),
|
||||
):
|
||||
yield adapter, inference_adapter, mock_make_request, mock_client
|
||||
|
||||
|
||||
def assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper function to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
|
||||
async def test_supervised_fine_tune(nvidia_adapters):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
}
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
self.inference_make_request_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
self.inference_make_request_patcher.stop()
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_supervised_fine_tune(self):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
|
||||
async def test_supervised_fine_tune_with_qat(nvidia_adapters):
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with pytest.raises(NotImplementedError):
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"customizer_status,expected_status",
|
||||
[
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
],
|
||||
)
|
||||
async def test_get_training_job_status(nvidia_adapters, customizer_status, expected_status):
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
status = await adapter.get_training_job_status(job_uuid=job_id)
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"training_type": "sft",
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
|
||||
async def test_get_training_jobs(nvidia_adapters):
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
mock_make_request.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": job_id,
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"config": {
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_supervised_fine_tune_with_qat(self):
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_training_job_status(self):
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
"dataset": {"name": "default/sample-basic-test"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "completed",
|
||||
"project": "default",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
jobs = await adapter.get_training_jobs()
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||
assert len(jobs.data) == 1
|
||||
job = jobs.data[0]
|
||||
assert job.job_uuid == job_id
|
||||
assert job.status.value == "completed"
|
||||
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"GET",
|
||||
"/v1/customization/jobs",
|
||||
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
||||
)
|
||||
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
async def test_cancel_training_job(nvidia_adapters):
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
def test_get_training_jobs(self):
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
self.mock_make_request.return_value = {
|
||||
"data": [
|
||||
{
|
||||
"id": job_id,
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"config": {
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
"dataset": {"name": "default/sample-basic-test"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "completed",
|
||||
"project": "default",
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
jobs = self.run_async(self.adapter.get_training_jobs())
|
||||
result = await adapter.cancel_training_job(job_uuid=job_id)
|
||||
|
||||
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
|
||||
assert len(jobs.data) == 1
|
||||
job = jobs.data[0]
|
||||
assert job.job_uuid == job_id
|
||||
assert job.status.value == "completed"
|
||||
assert result is None
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"GET",
|
||||
"/v1/customization/jobs",
|
||||
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
|
||||
mock_make_request.assert_called_once()
|
||||
assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
f"/v1/customization/jobs/{job_id}/cancel",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
|
||||
async def test_inference_register_model(nvidia_adapters):
|
||||
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
|
||||
|
||||
model_id = "default/job-1234"
|
||||
model_type = ModelType.llm
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_id="nvidia",
|
||||
provider_model_id=model_id,
|
||||
provider_resource_id=model_id,
|
||||
model_type=model_type,
|
||||
)
|
||||
result = await inference_adapter.register_model(model)
|
||||
assert result == model
|
||||
assert len(inference_adapter.alias_to_provider_id_map) > 1
|
||||
assert inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
||||
|
||||
with patch.object(inference_adapter, "chat_completion") as mock_chat_completion:
|
||||
await inference_adapter.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[{"role": "user", "content": "Hello, model"}],
|
||||
)
|
||||
|
||||
def test_cancel_training_job(self):
|
||||
self.mock_make_request.return_value = {} # Empty response for successful cancellation
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
f"/v1/customization/jobs/{job_id}/cancel",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
def test_inference_register_model(self):
|
||||
model_id = "default/job-1234"
|
||||
model_type = ModelType.llm
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_id="nvidia",
|
||||
provider_model_id=model_id,
|
||||
provider_resource_id=model_id,
|
||||
model_type=model_type,
|
||||
)
|
||||
result = self.run_async(self.inference_adapter.register_model(model))
|
||||
assert result == model
|
||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
||||
|
||||
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
|
||||
self.run_async(
|
||||
self.inference_adapter.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[{"role": "user", "content": "Hello, model"}],
|
||||
)
|
||||
)
|
||||
|
||||
mock_chat_completion.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
mock_chat_completion.assert_called()
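The status-mapping test above also shows the subTest-to-parametrize conversion; a stripped-down, self-contained sketch of that pattern follows (map_status is a toy mapper, but the status pairs mirror the ones asserted in this file):

import pytest

# Maps the customizer's job status to the status exposed by the adapter,
# mirroring the pairs exercised in test_get_training_job_status above.
STATUS_MAP = {
    "running": "in_progress",
    "completed": "completed",
    "failed": "failed",
    "cancelled": "cancelled",
    "pending": "scheduled",
}


def map_status(customizer_status: str) -> str:
    """Toy mapper standing in for the adapter's status translation."""
    return STATUS_MAP.get(customizer_status, "scheduled")


# Each (input, expected) pair becomes its own test case, replacing the
# unittest for-loop with self.subTest(...).
@pytest.mark.parametrize(
    "customizer_status,expected_status",
    [
        ("running", "in_progress"),
        ("completed", "completed"),
        ("failed", "failed"),
        ("cancelled", "cancelled"),
        ("pending", "scheduled"),
        ("unknown", "scheduled"),
    ],
)
def test_map_status(customizer_status, expected_status):
    assert map_status(customizer_status) == expected_status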
@@ -5,73 +5,91 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
|
||||
|
||||
class TestReplaceEnvVars(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Clear any existing environment variables we'll use in tests
|
||||
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_env_vars():
|
||||
"""Set up test environment variables and clean up after test"""
|
||||
# Store original values
|
||||
original_vars = {}
|
||||
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
|
||||
original_vars[var] = os.environ.get(var)
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
# Set up test environment variables
|
||||
os.environ["TEST_VAR"] = "test_value"
|
||||
os.environ["EMPTY_VAR"] = ""
|
||||
os.environ["ZERO_VAR"] = "0"
|
||||
# Set up test environment variables
|
||||
os.environ["TEST_VAR"] = "test_value"
|
||||
os.environ["EMPTY_VAR"] = ""
|
||||
os.environ["ZERO_VAR"] = "0"
|
||||
|
||||
def test_simple_replacement(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
|
||||
yield
|
||||
|
||||
def test_default_value_when_not_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default")
|
||||
|
||||
def test_default_value_when_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")
|
||||
|
||||
def test_default_value_when_empty(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default")
|
||||
|
||||
def test_none_value_when_empty(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)
|
||||
|
||||
def test_value_when_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value")
|
||||
|
||||
def test_empty_var_no_default(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)
|
||||
|
||||
def test_conditional_value_when_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional")
|
||||
|
||||
def test_conditional_value_when_not_set(self):
|
||||
self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)
|
||||
|
||||
def test_conditional_value_when_empty(self):
|
||||
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None)
|
||||
|
||||
def test_conditional_value_with_zero(self):
|
||||
self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")
|
||||
|
||||
def test_mixed_syntax(self):
|
||||
self.assertEqual(
|
||||
replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
|
||||
)
|
||||
self.assertEqual(
|
||||
replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
|
||||
)
|
||||
|
||||
def test_nested_structures(self):
|
||||
data = {
|
||||
"key1": "${env.TEST_VAR:=default}",
|
||||
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
|
||||
"key3": {"nested": "${env.NOT_SET:+conditional}"},
|
||||
}
|
||||
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
|
||||
self.assertEqual(replace_env_vars(data), expected)
|
||||
# Cleanup: restore original values
|
||||
for var, original_value in original_vars.items():
|
||||
if original_value is not None:
|
||||
os.environ[var] = original_value
|
||||
elif var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
def test_simple_replacement():
|
||||
assert replace_env_vars("${env.TEST_VAR}") == "test_value"
|
||||
|
||||
|
||||
def test_default_value_when_not_set():
|
||||
assert replace_env_vars("${env.NOT_SET:=default}") == "default"
|
||||
|
||||
|
||||
def test_default_value_when_set():
|
||||
assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
|
||||
|
||||
|
||||
def test_default_value_when_empty():
|
||||
assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
|
||||
|
||||
|
||||
def test_none_value_when_empty():
|
||||
assert replace_env_vars("${env.EMPTY_VAR:=}") is None
|
||||
|
||||
|
||||
def test_value_when_set():
|
||||
assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
|
||||
|
||||
|
||||
def test_empty_var_no_default():
|
||||
assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
|
||||
|
||||
|
||||
def test_conditional_value_when_set():
|
||||
assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
|
||||
|
||||
|
||||
def test_conditional_value_when_not_set():
|
||||
assert replace_env_vars("${env.NOT_SET:+conditional}") is None
|
||||
|
||||
|
||||
def test_conditional_value_when_empty():
|
||||
assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
|
||||
|
||||
|
||||
def test_conditional_value_with_zero():
|
||||
assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
|
||||
|
||||
|
||||
def test_mixed_syntax():
|
||||
assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
|
||||
assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
|
||||
|
||||
|
||||
def test_nested_structures():
|
||||
data = {
|
||||
"key1": "${env.TEST_VAR:=default}",
|
||||
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
|
||||
"key3": {"nested": "${env.NOT_SET:+conditional}"},
|
||||
}
|
||||
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
|
||||
assert replace_env_vars(data) == expected
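Taken together, these cases pin down the substitution rules: ${env.VAR} expands to the variable's value, :=default applies when the variable is unset or empty, and :+value applies only when the variable is set and non-empty (otherwise the result is None, with "0" still counting as set). A compact parametrized restatement of the same behavior, assuming it lives in this module so the autouse setup_env_vars fixture provides TEST_VAR, EMPTY_VAR, and ZERO_VAR:

import pytest

from llama_stack.distribution.stack import replace_env_vars


# Same behavior as the individual tests above, expressed as one table,
# given TEST_VAR="test_value", EMPTY_VAR="", ZERO_VAR="0".
@pytest.mark.parametrize(
    "template,expected",
    [
        ("${env.TEST_VAR}", "test_value"),
        ("${env.NOT_SET:=default}", "default"),           # := default when unset
        ("${env.EMPTY_VAR:=default}", "default"),         # := default when empty
        ("${env.TEST_VAR:=default}", "test_value"),       # := ignored when set
        ("${env.TEST_VAR:+conditional}", "conditional"),  # :+ value when set
        ("${env.NOT_SET:+conditional}", None),            # :+ None when unset
        ("${env.EMPTY_VAR:+conditional}", None),          # :+ None when empty
        ("${env.ZERO_VAR:+conditional}", "conditional"),  # "0" still counts as set
    ],
)
def test_replace_env_vars_table(template, expected):
    assert replace_env_vars(template) == expected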