fix: Pass model parameter as config name to NeMo Customizer (#2218)

# What does this PR do?
When launching a fine-tuning job, an upcoming version of NeMo Customizer
will expect the `config` name to be formatted as
`namespace/name@version`. Here, `config` is a reference to a model +
additional metadata. There could be multiple `config`s that reference
the same base model.
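
For illustration only (this is not code in the PR), a config reference such as `meta/llama-3.2-1b-instruct@v1.0.0+A100` splits into a namespace, a config name, and a version; a minimal sketch of parsing that format:
```
import re

# Hypothetical helper, not part of this PR: split a NeMo Customizer
# config reference of the form "namespace/name@version".
CONFIG_REF = re.compile(r"^(?P<namespace>[^/]+)/(?P<name>[^@]+)@(?P<version>.+)$")

def parse_config_ref(ref: str) -> dict:
    match = CONFIG_REF.match(ref)
    if match is None:
        raise ValueError(f"expected namespace/name@version, got {ref!r}")
    return match.groupdict()

print(parse_config_ref("meta/llama-3.2-1b-instruct@v1.0.0+A100"))
# {'namespace': 'meta', 'name': 'llama-3.2-1b-instruct', 'version': 'v1.0.0+A100'}
```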

This PR updates NVIDIA's `supervised_fine_tune` to pass the `model` param as-is to NeMo Customizer. Currently, it expects a specific, allowlisted Llama model (e.g. `meta-llama/Llama-3.1-8B-Instruct`) and converts it to the provider format (`meta/llama-3.1-8b-instruct`).
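
To make the change concrete, here is a rough before/after sketch; `_MODEL_ENTRIES` below is a simplified stand-in for the adapter's allowlist, not its actual contents:
```
# Before this PR (sketch): only allowlisted models were accepted, and the
# identifier was rewritten into the provider format before submission.
_MODEL_ENTRIES = {"meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct"}

def old_config_name(model: str) -> str:
    return _MODEL_ENTRIES[model]  # KeyError for anything not allowlisted

# After this PR: the model string is forwarded verbatim as the config name.
def new_config_name(model: str) -> str:
    return model

print(old_config_name("meta-llama/Llama-3.1-8B-Instruct"))       # meta/llama-3.1-8b-instruct
print(new_config_name("meta/llama-3.2-1b-instruct@v1.0.0+A100"))  # unchanged
```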


## Test Plan
From a notebook, I built an image with my changes:
```
# Build the nvidia distribution, then create an in-process client
!llama stack build --template nvidia --image-type venv

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
And I could successfully launch a job:
```
response = client.post_training.supervised_fine_tune(
    job_uuid="",
    model="meta/llama-3.2-1b-instruct@v1.0.0+A100",  # Model passed as-is to Customizer
    ...
)

job_id = response.job_uuid
print(f"Created job with ID: {job_id}")
```
Output:
```
Created job with ID: cust-Jm4oGmbwcvoufaLU4XkrRU
```
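
As an extra sanity check (not part of the original test plan), the job could also be looked up against the Customizer's `/v1/customization/jobs` endpoint, which the adapter POSTs to in the diff below; the base URL here is an assumed deployment address, and a GET listing is assumed to be supported:
```
import requests

# Hypothetical follow-up check; CUSTOMIZER_URL is an assumed address.
CUSTOMIZER_URL = "http://localhost:8000"

resp = requests.get(f"{CUSTOMIZER_URL}/v1/customization/jobs", timeout=30)
resp.raise_for_status()
for job in resp.json().get("data", []):
    print(job.get("id"), job.get("status"))
```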


---------

Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
Jash Gulabrai, 2025-05-20 12:51:39 -04:00, committed by GitHub
commit 1a770cf8ac (parent 2eae8568e1)
3 changed files with 7 additions and 10 deletions


```
@@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         Parameters:
             training_config: TrainingConfig - Configuration for training
-            model: str - Model identifier
+            model: str - NeMo Customizer configuration name
             algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
             checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
             job_uuid: str - Unique identifier for the job, ignored atm
@@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """
-        # Map model to nvidia model name
-        # See `_MODEL_ENTRIES` for supported models
-        nvidia_model = self.get_provider_model_id(model)

         # Check for unsupported method parameters
         unsupported_method_params = []
@@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         # Prepare base job configuration
         job_config = {
-            "config": nvidia_model,
+            "config": model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
                 "namespace": self.config.dataset_namespace,
```


```
@@ -131,7 +131,7 @@ class TestNvidiaParameters(unittest.TestCase):
     def test_required_parameters_passed(self):
         """Test scenario 2: When required parameters are passed."""
-        required_model = "meta-llama/Llama-3.1-8B-Instruct"
+        required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
         required_dataset_id = "required-dataset"
         required_job_uuid = "required-job"
@@ -190,7 +190,7 @@ class TestNvidiaParameters(unittest.TestCase):
         self.mock_make_request.assert_called_once()
         call_args = self.mock_make_request.call_args
-        assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
+        assert call_args[1]["json"]["config"] == required_model
         assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id

     def test_unsupported_parameters_warning(self):
```


```
@@ -165,7 +165,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
         training_job = self.run_async(
             self.adapter.supervised_fine_tune(
                 job_uuid="1234",
-                model="meta-llama/Llama-3.1-8B-Instruct",
+                model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
                 checkpoint_dir="",
                 algorithm_config=algorithm_config,
                 training_config=convert_pydantic_to_json_value(training_config),
@@ -184,7 +184,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
             "POST",
             "/v1/customization/jobs",
             expected_json={
-                "config": "meta/llama-3.1-8b-instruct",
+                "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
                 "dataset": {"name": "sample-basic-test", "namespace": "default"},
                 "hyperparameters": {
                     "training_type": "sft",
@@ -219,7 +219,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
         self.run_async(
             self.adapter.supervised_fine_tune(
                 job_uuid="1234",
-                model="meta-llama/Llama-3.1-8B-Instruct",
+                model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
                 checkpoint_dir="",
                 algorithm_config=algorithm_config,
                 training_config=convert_pydantic_to_json_value(training_config),
```