mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fix: Pass model param as configuration name to NeMo Customizer
This commit is contained in:
parent
ed7b4731aa
commit
1d94f3617a
2 changed files with 5 additions and 8 deletions
|
@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
training_config: TrainingConfig - Configuration for training
|
training_config: TrainingConfig - Configuration for training
|
||||||
model: str - Model identifier
|
model: str - NeMo Customizer configuration name
|
||||||
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
||||||
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
||||||
job_uuid: str - Unique identifier for the job, ignored atm
|
job_uuid: str - Unique identifier for the job, ignored atm
|
||||||
|
@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
User is informed about unsupported parameters via warnings.
|
User is informed about unsupported parameters via warnings.
|
||||||
"""
|
"""
|
||||||
# Map model to nvidia model name
|
|
||||||
# See `_MODEL_ENTRIES` for supported models
|
|
||||||
nvidia_model = self.get_provider_model_id(model)
|
|
||||||
|
|
||||||
# Check for unsupported method parameters
|
# Check for unsupported method parameters
|
||||||
unsupported_method_params = []
|
unsupported_method_params = []
|
||||||
|
@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
# Prepare base job configuration
|
# Prepare base job configuration
|
||||||
job_config = {
|
job_config = {
|
||||||
"config": nvidia_model,
|
"config": model,
|
||||||
"dataset": {
|
"dataset": {
|
||||||
"name": training_config["data_config"]["dataset_id"],
|
"name": training_config["data_config"]["dataset_id"],
|
||||||
"namespace": self.config.dataset_namespace,
|
"namespace": self.config.dataset_namespace,
|
||||||
|
|
|
@ -165,7 +165,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
training_job = self.run_async(
|
training_job = self.run_async(
|
||||||
self.adapter.supervised_fine_tune(
|
self.adapter.supervised_fine_tune(
|
||||||
job_uuid="1234",
|
job_uuid="1234",
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
|
@ -184,7 +184,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
"POST",
|
"POST",
|
||||||
"/v1/customization/jobs",
|
"/v1/customization/jobs",
|
||||||
expected_json={
|
expected_json={
|
||||||
"config": "meta/llama-3.1-8b-instruct",
|
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||||
"hyperparameters": {
|
"hyperparameters": {
|
||||||
"training_type": "sft",
|
"training_type": "sft",
|
||||||
|
@ -219,7 +219,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
self.run_async(
|
self.run_async(
|
||||||
self.adapter.supervised_fine_tune(
|
self.adapter.supervised_fine_tune(
|
||||||
job_uuid="1234",
|
job_uuid="1234",
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=convert_pydantic_to_json_value(training_config),
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue