add model mapping

This commit is contained in:
Ubuntu 2025-03-12 14:32:08 +00:00 committed by raspawar
parent bd9b6a6e00
commit d7ead08cb9
4 changed files with 37 additions and 6 deletions

View file

@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
# Static registry of models exposed by the NVIDIA post-training provider.
# Each entry maps the provider-side model id ("meta/llama-3.1-8b-instruct")
# to the corresponding core Llama model identifier.
_MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
]
def get_model_entries() -> List[ProviderModelEntry]:
    """Return the provider's static list of supported model entries."""
    entries = _MODEL_ENTRIES
    return entries

View file

@@ -22,8 +22,11 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.remote.post_training.nvidia.config import (
NvidiaPostTrainingConfig,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.schema_utils import webmethod
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
@@ -51,7 +54,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingAdapter:
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
self.config = config
self.headers = {}
@@ -59,6 +62,8 @@ class NvidiaPostTrainingAdapter:
self.headers["Authorization"] = f"Bearer {config.api_key}"
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
async def _make_request(
self,
@@ -200,11 +205,7 @@ class NvidiaPostTrainingAdapter:
- NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
"""
# map model to nvidia model name
model_mapping = {
"Llama3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
"meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
}
nvidia_model = model_mapping.get(model, model)
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported parameters
if checkpoint_dir or hyperparam_search_config or logger_config:

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -22,6 +22,7 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
class TestNvidiaPostTraining(unittest.TestCase):
# ToDo: add tests for env variables, models supported.
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer