mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: Pass model parameter as config name to NeMo Customizer (#2218)
# What does this PR do? When launching a fine-tuning job, an upcoming version of NeMo Customizer will expect the `config` name to be formatted as `namespace/name@version`. Here, `config` is a reference to a model + additional metadata. There could be multiple `config`s that reference the same base model. This PR updates NVIDIA's `supervised_fine_tune` to simply pass the `model` param as-is to NeMo Customizer. Currently, it expects a specific, allowlisted llama model (i.e. `meta/Llama3.1-8B-Instruct`) and converts it to the provider format (`meta/llama-3.1-8b-instruct`). [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan From a notebook, I built an image with my changes: ``` !llama stack build --template nvidia --image-type venv from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() ``` And could successfully launch a job: ``` response = client.post_training.supervised_fine_tune( job_uuid="", model="meta/llama-3.2-1b-instruct@v1.0.0+A100", # Model passed as-is to Customimzer ... ) job_id = response.job_uuid print(f"Created job with ID: {job_id}") Output: Created job with ID: cust-Jm4oGmbwcvoufaLU4XkrRU ``` [//]: # (## Documentation) --------- Co-authored-by: Jash Gulabrai <jgulabrai@nvidia.com>
This commit is contained in:
parent
2eae8568e1
commit
1a770cf8ac
3 changed files with 7 additions and 10 deletions
|
@ -131,7 +131,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
|
||||
def test_required_parameters_passed(self):
|
||||
"""Test scenario 2: When required parameters are passed."""
|
||||
required_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
|
||||
required_dataset_id = "required-dataset"
|
||||
required_job_uuid = "required-job"
|
||||
|
||||
|
@ -190,7 +190,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
self.mock_make_request.assert_called_once()
|
||||
call_args = self.mock_make_request.call_args
|
||||
|
||||
assert call_args[1]["json"]["config"] == "meta/llama-3.1-8b-instruct"
|
||||
assert call_args[1]["json"]["config"] == required_model
|
||||
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
|
||||
|
||||
def test_unsupported_parameters_warning(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue