diff --git a/llama_stack/providers/remote/post_training/nvidia/models.py b/llama_stack/providers/remote/post_training/nvidia/models.py
new file mode 100644
index 000000000..04a9af38c
--- /dev/null
+++ b/llama_stack/providers/remote/post_training/nvidia/models.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
+    build_hf_repo_model_entry,
+)
+
+_MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
+        "meta/llama-3.1-8b-instruct",
+        CoreModelId.llama3_1_8b_instruct.value,
+    )
+]
+
+
+def get_model_entries() -> List[ProviderModelEntry]:
+    return _MODEL_ENTRIES
diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py
index 9af480838..1c517c774 100644
--- a/llama_stack/providers/remote/post_training/nvidia/post_training.py
+++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py
@@ -22,8 +22,11 @@ from llama_stack.apis.post_training import (
 from llama_stack.providers.remote.post_training.nvidia.config import (
     NvidiaPostTrainingConfig,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.schema_utils import webmethod
 
+from .models import _MODEL_ENTRIES
+
 # Map API status to JobStatus enum
 STATUS_MAPPING = {
     "running": "in_progress",
@@ -51,7 +54,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
     data: List[NvidiaPostTrainingJob]
 
 
-class NvidiaPostTrainingAdapter:
+class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     def __init__(self, config: NvidiaPostTrainingConfig):
         self.config = config
         self.headers = {}
@@ -59,6 +62,8 @@ class NvidiaPostTrainingAdapter:
             self.headers["Authorization"] = f"Bearer {config.api_key}"
 
         self.timeout = aiohttp.ClientTimeout(total=config.timeout)
+        # TODO(mf): filter by available models
+        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
 
     async def _make_request(
         self,
@@ -200,11 +205,7 @@ class NvidiaPostTrainingAdapter:
         - NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
         """
         # map model to nvidia model name
-        model_mapping = {
-            "Llama3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
-            "meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
-        }
-        nvidia_model = model_mapping.get(model, model)
+        nvidia_model = self.get_provider_model_id(model)
 
         # Check for unsupported parameters
         if checkpoint_dir or hyperparam_search_config or logger_config:
diff --git a/tests/unit/providers/nvidia/__init__.py b/tests/unit/providers/nvidia/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/tests/unit/providers/nvidia/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
index 0c5b11561..dfdca39d1 100644
--- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
+++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
@@ -22,6 +22,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 
 
 class TestNvidiaPostTraining(unittest.TestCase):
+    # TODO: add tests for env variables and supported models.
     def setUp(self):
         os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"  # needed for llm inference
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
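
Reviewer context: the substance of this change is that the hardcoded model_mapping dict in supervised fine-tuning is replaced by the shared ModelRegistryHelper, initialized from _MODEL_ENTRIES in the new models.py. The sketch below is a minimal, self-contained stand-in for the alias-resolution behavior the adapter now relies on; it is not llama-stack code, and the exact alias set registered by build_hf_repo_model_entry (core descriptor plus HF repo name) is an assumption for illustration.

# Illustrative stand-in, NOT the llama-stack implementation. Mimics the lookup
# ModelRegistryHelper performs once the adapter initializes it with
# _MODEL_ENTRIES; the alias set shown is an assumption.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ModelEntry:
    provider_model_id: str  # NVIDIA-side name, e.g. "meta/llama-3.1-8b-instruct"
    aliases: List[str] = field(default_factory=list)  # alternative names callers may pass


def build_alias_map(entries: List[ModelEntry]) -> Dict[str, str]:
    """Flatten entries into a lookup table: every known name -> provider id."""
    alias_map: Dict[str, str] = {}
    for entry in entries:
        alias_map[entry.provider_model_id] = entry.provider_model_id
        for alias in entry.aliases:
            alias_map[alias] = entry.provider_model_id
    return alias_map


def get_provider_model_id(alias_map: Dict[str, str], model: str) -> Optional[str]:
    """Registry-style lookup: unregistered names resolve to None."""
    return alias_map.get(model)


entries = [
    ModelEntry(
        provider_model_id="meta/llama-3.1-8b-instruct",
        aliases=["Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"],
    )
]
alias_map = build_alias_map(entries)

assert get_provider_model_id(alias_map, "meta-llama/Llama-3.1-8B-Instruct") == "meta/llama-3.1-8b-instruct"
assert get_provider_model_id(alias_map, "meta/llama-3.1-8b-instruct") == "meta/llama-3.1-8b-instruct"
assert get_provider_model_id(alias_map, "some-unknown-model") is None

One behavioral difference worth flagging in review: the removed model_mapping.get(model, model) passed unknown model names through to the customizer unchanged, whereas a registry lookup only resolves registered entries, so unrecognized models now fail at mapping time rather than at the NVIDIA API.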