mirror of https://github.com/meta-llama/llama-stack.git
commit 34db80fb15: mock llamastackclient in unit tests
parent 0d4dc06a3c
4 changed files with 294 additions and 44 deletions
@@ -7,7 +7,7 @@
 import os
 import unittest
 from unittest.mock import AsyncMock, MagicMock, patch
-
+import warnings
 import pytest
 from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
 from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
@@ -19,6 +19,7 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
 )

 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from .mock_llama_stack_client import MockLlamaStackClient


 class TestNvidiaPostTraining(unittest.TestCase):
@@ -28,10 +29,9 @@ class TestNvidiaPostTraining(unittest.TestCase):
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
         os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"  # mocking llama stack base url

-        self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
-        _ = self.llama_stack_client.initialize()
-
-        ## ToDo: post health checks for customizer are enabled, include test cases for NVIDIA_CUSTOMIZER_URL
+        with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
+            self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
+            _ = self.llama_stack_client.initialize()

     def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
         """Helper method to verify request details in mock calls."""
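The mock module imported above, ./mock_llama_stack_client, is one of the four files changed by this commit but its contents are not shown in this excerpt. As a hedged sketch only -- every name and behavior below is an assumption, not the committed code -- a MockLlamaStackClient sufficient for this setUp could be as small as:

    # Hypothetical sketch, not the committed mock_llama_stack_client.py.
    # Assumes the tests only need initialize() plus attribute access that can
    # later be asserted on (e.g. client.post_training.supervised_fine_tune).
    from unittest.mock import MagicMock


    class MockLlamaStackClient:
        def __init__(self, config_name, **kwargs):
            self.config_name = config_name  # e.g. "nvidia"
            self._mock = MagicMock()  # backs every API namespace

        def initialize(self):
            # The real client wires up providers here; the mock just succeeds.
            return True

        def __getattr__(self, name):
            # Unknown attributes fall through to the MagicMock, so calls made
            # on any API namespace can be inspected by the tests.
            return getattr(self._mock, name)

One unittest.mock subtlety applies to the hunk above: patch() replaces a name in the namespace it is looked up from, so a module that ran "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient" at import time keeps its original reference even while the attribute on library_client is patched; patching the name in the module that makes the call is the usual remedy.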
@@ -125,15 +125,29 @@ class TestNvidiaPostTraining(unittest.TestCase):
             optimizer_config=optimizer_config,
         )

-        training_job = self.llama_stack_client.post_training.supervised_fine_tune(
-            job_uuid="1234",
-            model="meta-llama/Llama-3.1-8B-Instruct",
-            checkpoint_dir="",
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            logger_config={},
-            hyperparam_search_config={},
-        )
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            training_job = self.llama_stack_client.post_training.supervised_fine_tune(
+                job_uuid="1234",
+                model="meta-llama/Llama-3.1-8B-Instruct",
+                checkpoint_dir="",
+                algorithm_config=algorithm_config,
+                training_config=training_config,
+                logger_config={},
+                hyperparam_search_config={},
+            )
+
+            # required lora config unsupported parameters warnings
+            fields = [
+                "apply_lora_to_mlp",
+                "rank",
+                "use_dora",
+                "lora_attn_modules",
+                "quantize_base",
+                "apply_lora_to_output",
+            ]
+            for field in fields:
+                assert any(field in str(warning.message) for warning in w)

         # check the output is a PostTrainingJob
         # Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
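The catch_warnings(record=True) block added here collects every warning raised during the call so the test can assert on message contents afterwards. A self-contained illustration of the same pattern, where emit_unsupported is a stand-in invented for this sketch (not the provider's actual code path):

    import warnings


    def emit_unsupported(field):
        # Stand-in for the adapter code that warns about ignored LoRA params.
        warnings.warn(f"Parameter {field} is not supported and will be ignored.", UserWarning)


    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")  # record everything, including duplicates
        emit_unsupported("use_dora")

    # Each recorded entry wraps the original warning; its text is in .message.
    assert any("use_dora" in str(warning.message) for warning in w)

Without simplefilter("always"), Python's default "once" filtering could suppress repeated warnings and make assertions like these sensitive to test ordering.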
@@ -149,7 +163,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
             "/v1/customization/jobs",
             expected_json={
                 "config": "meta/llama-3.1-8b-instruct",
-                "dataset": {"name": "sample-basic-test", "namespace": ""},
+                "dataset": {"name": "sample-basic-test", "namespace": "default"},
                 "hyperparameters": {
                     "training_type": "sft",
                     "finetuning_type": "lora",
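The namespace value asserted here ("default" instead of "") has to match whatever the mocked HTTP layer actually received. The body of _assert_request is not shown in this excerpt; a plausible sketch, assuming the mocked transport is called as request(method, path, params=..., json=...) -- an assumption, not the committed signature -- would be:

    # Hypothetical implementation matching the docstring shown earlier in the
    # diff; the committed helper may differ.
    def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
        """Helper method to verify request details in mock calls."""
        args, kwargs = mock_call.call_args
        assert args[0] == expected_method
        assert args[1] == expected_path
        if expected_params:
            assert kwargs.get("params") == expected_params
        if expected_json:
            # Compare only the keys the test specifies, e.g. the dataset
            # payload {"name": "sample-basic-test", "namespace": "default"}.
            for key, value in expected_json.items():
                assert kwargs.get("json", {})[key] == value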