mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
mock llamastackclient in unit tests
This commit is contained in:
parent
0d4dc06a3c
commit
34db80fb15
4 changed files with 294 additions and 44 deletions
20
tests/unit/providers/nvidia/conftest.py
Normal file
20
tests/unit/providers/nvidia/conftest.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.closed = False
|
||||||
|
mock_session.close = AsyncMock()
|
||||||
|
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def patch_aiohttp_session():
|
||||||
|
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||||
|
yield
|
147
tests/unit/providers/nvidia/mock_llama_stack_client.py
Normal file
147
tests/unit/providers/nvidia/mock_llama_stack_client.py
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack_client.types.algorithm_config_param import QatFinetuningConfig
|
||||||
|
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
|
||||||
|
from llama_stack_client.types.post_training_job import PostTrainingJob
|
||||||
|
|
||||||
|
|
||||||
|
class MockLlamaStackClient:
|
||||||
|
"""Mock client for testing NVIDIA post-training functionality."""
|
||||||
|
|
||||||
|
def __init__(self, provider="nvidia"):
|
||||||
|
self.provider = provider
|
||||||
|
self.post_training = MockPostTraining()
|
||||||
|
self.inference = MockInference()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Mock initialization method."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close any open resources."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockPostTraining:
|
||||||
|
"""Mock post-training module."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.job = MockPostTrainingJob()
|
||||||
|
|
||||||
|
def supervised_fine_tune(
|
||||||
|
self,
|
||||||
|
job_uuid,
|
||||||
|
model,
|
||||||
|
checkpoint_dir,
|
||||||
|
algorithm_config,
|
||||||
|
training_config,
|
||||||
|
logger_config,
|
||||||
|
hyperparam_search_config,
|
||||||
|
):
|
||||||
|
"""Mock supervised fine-tuning method."""
|
||||||
|
if isinstance(algorithm_config, QatFinetuningConfig):
|
||||||
|
raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
|
||||||
|
|
||||||
|
# Return a mock PostTrainingJob
|
||||||
|
return PostTrainingJob(
|
||||||
|
job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
status="created",
|
||||||
|
created_at="2024-12-09T04:06:28.542884",
|
||||||
|
updated_at="2024-12-09T04:06:28.542884",
|
||||||
|
model=model,
|
||||||
|
dataset_id=training_config.data_config.dataset_id,
|
||||||
|
output_model="default/job-1234",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def supervised_fine_tune_async(
|
||||||
|
self,
|
||||||
|
job_uuid,
|
||||||
|
model,
|
||||||
|
checkpoint_dir,
|
||||||
|
algorithm_config,
|
||||||
|
training_config,
|
||||||
|
logger_config,
|
||||||
|
hyperparam_search_config,
|
||||||
|
):
|
||||||
|
"""Mock async supervised fine-tuning method."""
|
||||||
|
if isinstance(algorithm_config, QatFinetuningConfig):
|
||||||
|
raise NotImplementedError("QAT fine-tuning is not supported by NVIDIA provider")
|
||||||
|
|
||||||
|
# Return a mock response dictionary
|
||||||
|
return {
|
||||||
|
"job_uuid": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
"status": "created",
|
||||||
|
"created_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"updated_at": "2024-12-09T04:06:28.542884",
|
||||||
|
"model": model,
|
||||||
|
"dataset_id": training_config.data_config.dataset_id,
|
||||||
|
"output_model": "default/job-1234",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MockPostTrainingJob:
|
||||||
|
"""Mock post-training job module."""
|
||||||
|
|
||||||
|
def status(self, job_uuid):
|
||||||
|
"""Mock job status method."""
|
||||||
|
return JobStatusResponse(
|
||||||
|
status="completed",
|
||||||
|
steps_completed=1210,
|
||||||
|
epochs_completed=2,
|
||||||
|
percentage_done=100.0,
|
||||||
|
best_epoch=2,
|
||||||
|
train_loss=1.718016266822815,
|
||||||
|
val_loss=1.8661999702453613,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list(self):
|
||||||
|
"""Mock job list method."""
|
||||||
|
return [
|
||||||
|
PostTrainingJob(
|
||||||
|
job_uuid="cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||||
|
status="completed",
|
||||||
|
created_at="2024-12-09T04:06:28.542884",
|
||||||
|
updated_at="2024-12-09T04:21:19.852832",
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
dataset_id="sample-basic-test",
|
||||||
|
output_model="default/job-1234",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def cancel(self, job_uuid):
|
||||||
|
"""Mock job cancel method."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MockInference:
|
||||||
|
"""Mock inference module."""
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
content,
|
||||||
|
stream=False,
|
||||||
|
model_id=None,
|
||||||
|
sampling_params=None,
|
||||||
|
):
|
||||||
|
"""Mock completion method."""
|
||||||
|
return {
|
||||||
|
"id": "cmpl-123456",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1677858242,
|
||||||
|
"model": model_id,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": "The next GTC will take place in the middle of March, 2023.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 100, "completion_tokens": 12, "total_tokens": 112},
|
||||||
|
}
|
|
@ -7,16 +7,30 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, AsyncMock, MagicMock
|
||||||
|
import pytest
|
||||||
|
import atexit
|
||||||
|
|
||||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
||||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
TrainingConfigDataConfig,
|
TrainingConfigDataConfig,
|
||||||
TrainingConfigEfficiencyConfig,
|
|
||||||
TrainingConfigOptimizerConfig,
|
TrainingConfigOptimizerConfig,
|
||||||
|
TrainingConfigEfficiencyConfig,
|
||||||
)
|
)
|
||||||
|
from .mock_llama_stack_client import MockLlamaStackClient
|
||||||
|
|
||||||
|
# Create a mock session
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session.closed = False
|
||||||
|
mock_session.close = AsyncMock()
|
||||||
|
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_session.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
|
patch("aiohttp.ClientSession", return_value=mock_session).start()
|
||||||
|
|
||||||
|
atexit.register(lambda: patch.stopall())
|
||||||
|
|
||||||
|
|
||||||
class TestNvidiaParameters(unittest.TestCase):
|
class TestNvidiaParameters(unittest.TestCase):
|
||||||
|
@ -25,8 +39,10 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
|
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
|
||||||
|
|
||||||
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
|
# Use the mock client
|
||||||
_ = self.llama_stack_client.initialize()
|
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
|
||||||
|
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
_ = self.llama_stack_client.initialize()
|
||||||
|
|
||||||
self.make_request_patcher = patch(
|
self.make_request_patcher = patch(
|
||||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||||
|
@ -40,6 +56,10 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
# Close the client if it has a close method
|
||||||
|
if hasattr(self.llama_stack_client, "close"):
|
||||||
|
self.llama_stack_client.close()
|
||||||
|
|
||||||
self.make_request_patcher.stop()
|
self.make_request_patcher.stop()
|
||||||
|
|
||||||
def _assert_request_params(self, expected_json):
|
def _assert_request_params(self, expected_json):
|
||||||
|
@ -54,7 +74,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
else:
|
else:
|
||||||
assert actual_json[key] == value
|
assert actual_json[key] == value
|
||||||
|
|
||||||
def test_optional_parameter_passed(self):
|
def test_customizer_parameters_passed(self):
|
||||||
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
||||||
custom_adapter_dim = 32 # Different from default of 8
|
custom_adapter_dim = 32 # Different from default of 8
|
||||||
algorithm_config = LoraFinetuningConfig(
|
algorithm_config = LoraFinetuningConfig(
|
||||||
|
@ -76,15 +96,27 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
optimizer_config=optimizer_config,
|
optimizer_config=optimizer_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.llama_stack_client.post_training.supervised_fine_tune(
|
with warnings.catch_warnings(record=True) as w:
|
||||||
job_uuid="test-job",
|
warnings.simplefilter("always")
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
self.llama_stack_client.post_training.supervised_fine_tune(
|
||||||
checkpoint_dir="",
|
job_uuid="test-job",
|
||||||
algorithm_config=algorithm_config,
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
training_config=training_config,
|
checkpoint_dir="",
|
||||||
logger_config={},
|
algorithm_config=algorithm_config,
|
||||||
hyperparam_search_config={},
|
training_config=training_config,
|
||||||
)
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
warning_texts = [str(warning.message) for warning in w]
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
"apply_lora_to_output",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"apply_lora_to_mlp",
|
||||||
|
]
|
||||||
|
for field in fields:
|
||||||
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
self._assert_request_params(
|
self._assert_request_params(
|
||||||
{
|
{
|
||||||
|
@ -97,7 +129,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_required_parameter_passed(self):
|
def test_required_parameters_passed(self):
|
||||||
"""Test scenario 2: When required parameters are passed."""
|
"""Test scenario 2: When required parameters are passed."""
|
||||||
required_model = "meta-llama/Llama-3.1-8B-Instruct"
|
required_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
required_dataset_id = "required-dataset"
|
required_dataset_id = "required-dataset"
|
||||||
|
@ -127,25 +159,40 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
optimizer_config=optimizer_config,
|
optimizer_config=optimizer_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.llama_stack_client.post_training.supervised_fine_tune(
|
# catch required unsupported parameters warnings
|
||||||
job_uuid=required_job_uuid, # Required parameter
|
with warnings.catch_warnings(record=True) as w:
|
||||||
model=required_model, # Required parameter
|
warnings.simplefilter("always")
|
||||||
checkpoint_dir="",
|
self.llama_stack_client.post_training.supervised_fine_tune(
|
||||||
algorithm_config=algorithm_config,
|
job_uuid=required_job_uuid, # Required parameter
|
||||||
training_config=training_config,
|
model=required_model, # Required parameter
|
||||||
logger_config={},
|
checkpoint_dir="",
|
||||||
hyperparam_search_config={},
|
algorithm_config=algorithm_config,
|
||||||
)
|
training_config=training_config,
|
||||||
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
self.mock_make_request.assert_called_once()
|
self.mock_make_request.assert_called_once()
|
||||||
call_args = self.mock_make_request.call_args
|
call_args = self.mock_make_request.call_args
|
||||||
|
|
||||||
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
|
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
|
||||||
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
||||||
|
|
||||||
|
warning_texts = [str(warning.message) for warning in w]
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
"rank",
|
||||||
|
"use_dora",
|
||||||
|
"quantize_base",
|
||||||
|
"apply_lora_to_output",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"apply_lora_to_mlp",
|
||||||
|
]
|
||||||
|
for field in fields:
|
||||||
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
def test_unsupported_parameters_warning(self):
|
def test_unsupported_parameters_warning(self):
|
||||||
"""Test that warnings are raised for unsupported parameters."""
|
"""Test that warnings are raised for unsupported parameters."""
|
||||||
# Create a training config with unsupported parameters
|
|
||||||
data_config = TrainingConfigDataConfig(
|
data_config = TrainingConfigDataConfig(
|
||||||
dataset_id="test-dataset",
|
dataset_id="test-dataset",
|
||||||
batch_size=8,
|
batch_size=8,
|
||||||
|
@ -215,10 +262,32 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
"gradient_accumulation_steps",
|
"gradient_accumulation_steps",
|
||||||
"max_validation_steps",
|
"max_validation_steps",
|
||||||
"dtype",
|
"dtype",
|
||||||
|
# required unsupported parameters
|
||||||
|
"rank",
|
||||||
|
"use_dora",
|
||||||
|
"quantize_base",
|
||||||
|
"apply_lora_to_output",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"apply_lora_to_mlp",
|
||||||
]
|
]
|
||||||
for field in fields:
|
for field in fields:
|
||||||
assert any(field in text for text in warning_texts)
|
assert any(field in text for text in warning_texts)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llama_stack_client():
|
||||||
|
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
||||||
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
|
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002"
|
||||||
|
|
||||||
|
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
|
||||||
|
client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
_ = client.initialize()
|
||||||
|
yield client
|
||||||
|
|
||||||
|
if hasattr(client, "close"):
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import warnings
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
|
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
|
||||||
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
|
from llama_stack_client.types.post_training.job_status_response import JobStatusResponse
|
||||||
|
@ -19,6 +19,7 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||||
|
from .mock_llama_stack_client import MockLlamaStackClient
|
||||||
|
|
||||||
|
|
||||||
class TestNvidiaPostTraining(unittest.TestCase):
|
class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
|
@ -28,10 +29,9 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||||
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002" # mocking llama stack base url
|
os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5002" # mocking llama stack base url
|
||||||
|
|
||||||
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
|
with patch("llama_stack.distribution.library_client.LlamaStackAsLibraryClient", MockLlamaStackClient):
|
||||||
_ = self.llama_stack_client.initialize()
|
self.llama_stack_client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
_ = self.llama_stack_client.initialize()
|
||||||
## ToDo: post health checks for customizer are enabled, include test cases for NVIDIA_CUSTOMIZER_URL
|
|
||||||
|
|
||||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||||
"""Helper method to verify request details in mock calls."""
|
"""Helper method to verify request details in mock calls."""
|
||||||
|
@ -125,15 +125,29 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
optimizer_config=optimizer_config,
|
optimizer_config=optimizer_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
|
with warnings.catch_warnings(record=True) as w:
|
||||||
job_uuid="1234",
|
warnings.simplefilter("always")
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
training_job = self.llama_stack_client.post_training.supervised_fine_tune(
|
||||||
checkpoint_dir="",
|
job_uuid="1234",
|
||||||
algorithm_config=algorithm_config,
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
training_config=training_config,
|
checkpoint_dir="",
|
||||||
logger_config={},
|
algorithm_config=algorithm_config,
|
||||||
hyperparam_search_config={},
|
training_config=training_config,
|
||||||
)
|
logger_config={},
|
||||||
|
hyperparam_search_config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# required lora config unsupported parameters warnings
|
||||||
|
fields = [
|
||||||
|
"apply_lora_to_mlp",
|
||||||
|
"rank",
|
||||||
|
"use_dora",
|
||||||
|
"lora_attn_modules",
|
||||||
|
"quantize_base",
|
||||||
|
"apply_lora_to_output",
|
||||||
|
]
|
||||||
|
for field in fields:
|
||||||
|
assert any(field in str(warning.message) for warning in w)
|
||||||
|
|
||||||
# check the output is a PostTrainingJob
|
# check the output is a PostTrainingJob
|
||||||
# Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
|
# Note: Although the type is PostTrainingJob: llama_stack.apis.post_training.PostTrainingJob,
|
||||||
|
@ -149,7 +163,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
"/v1/customization/jobs",
|
"/v1/customization/jobs",
|
||||||
expected_json={
|
expected_json={
|
||||||
"config": "meta/llama-3.1-8b-instruct",
|
"config": "meta/llama-3.1-8b-instruct",
|
||||||
"dataset": {"name": "sample-basic-test", "namespace": ""},
|
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||||
"hyperparameters": {
|
"hyperparameters": {
|
||||||
"training_type": "sft",
|
"training_type": "sft",
|
||||||
"finetuning_type": "lora",
|
"finetuning_type": "lora",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue