mock llamastackclient in unit tests

This commit is contained in:
raspawar 2025-03-24 15:14:17 +05:30
parent 0d4dc06a3c
commit 34db80fb15
4 changed files with 294 additions and 44 deletions

View file

@ -7,7 +7,7 @@
import os
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import warnings
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
@ -19,6 +19,7 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
)
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from .mock_llama_stack_client import MockLlamaStackClient
class TestNvidiaPostTraining(unittest.TestCase):
@ -28,10 +29,9 @@ class TestNvidiaPostTraining(unittest.TestCase):
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002" # mocking llama stack base url
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
## ToDo: post health checks for customizer are enabled, include test cases for NVIDIA_CUSTOMIZER_URL
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
_ = self.llama_stack_client.initialize()
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
@ -125,15 +125,29 @@ class TestNvidiaPostTraining(unittest.TestCase):
optimizer_config=optimizer_config,
)
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
job_uuid="1234",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
# required lora config unsupported parameters warnings
fields = [
"apply_lora_to_mlp",
"rank",
"use_dora",
"lora_attn_modules",
"quantize_base",
"apply_lora_to_output",
]
for field in fields:
assert any(field in str(warning.message) for warning in w)
# check the output is a PostTrainingJob
# Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
@ -149,7 +163,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.1-8b-instruct",
"dataset": {"name": "sample-basic-test", "namespace": ""},
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",