rename impl, add config validation

Ubuntu 2025-03-12 07:41:03 +00:00 committed by raspawar
parent f5ebad130c
commit bd9a6d5f9c
2 changed files with 11 additions and 11 deletions

__init__.py

@@ -10,14 +10,18 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 
 from .config import NvidiaPostTrainingConfig
 
 # post_training api and the torchtune provider is still experimental and under heavy development
 
 async def get_adapter_impl(
     config: NvidiaPostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import NvidiaPostTrainingImpl
+    from .post_training import NvidiaPostTrainingAdapter
 
-    impl = NvidiaPostTrainingImpl(config)
+    if not isinstance(config, NvidiaPostTrainingConfig):
+        raise RuntimeError(f"Unexpected config type: {type(config)}")
+    impl = NvidiaPostTrainingAdapter(config)
     return impl
 
+
+__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
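
The isinstance guard makes get_adapter_impl fail fast when it is wired up with the wrong config type, instead of surfacing a confusing attribute error later inside the adapter. A minimal sketch of that behavior, using hypothetical stand-in config classes rather than the real llama_stack types:

# Sketch only: StubNvidiaConfig and StubOtherConfig are hypothetical
# stand-ins for NvidiaPostTrainingConfig and an unrelated provider config.
import asyncio
from dataclasses import dataclass

@dataclass
class StubNvidiaConfig:
    customizer_url: str = "http://localhost:8000"  # illustrative field

@dataclass
class StubOtherConfig:
    url: str = ""

async def get_adapter_impl(config, deps=None):
    # Same shape as the validation added in this commit: reject any
    # config that is not the expected type before building the adapter.
    if not isinstance(config, StubNvidiaConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    return config  # stand-in for NvidiaPostTrainingAdapter(config)

asyncio.run(get_adapter_impl(StubNvidiaConfig()))  # succeeds
try:
    asyncio.run(get_adapter_impl(StubOtherConfig()))
except RuntimeError as err:
    print(err)  # Unexpected config type: <class '__main__.StubOtherConfig'>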

post_training.py

@@ -50,7 +50,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
     data: List[NvidiaPostTrainingJob]
 
 
-class NvidiaPostTrainingImpl:
+class NvidiaPostTrainingAdapter:
     def __init__(self, config: NvidiaPostTrainingConfig):
         self.config = config
         self.headers = {}
@@ -226,12 +226,8 @@ class NvidiaPostTrainingImpl:
             # Extract LoRA-specific parameters
             lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
             job_config["hyperparameters"]["lora"] = lora_config
-
-            # Add adapter_dim if available in training_config
-            if training_config.get("algorithm_config", {}).get("adapter_dim"):
-                job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"]["adapter_dim"]
         else:
             raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
 
         # Create the customization job
         response = await self._make_request(
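
With the removed branch gone, adapter_dim is no longer back-filled from training_config; it has to arrive inside algorithm_config itself, where the dict comprehension sweeps it into lora_config along with every other non-"type" key. A rough sketch of the surviving LoRA path, assuming algorithm_config is a plain dict with a "type" discriminator; the guard condition and sample values are illustrative, not taken from the source:

# Illustrative sketch of the post-commit LoRA branch; the "type" guard and
# the sample values are assumptions, the comprehension mirrors the diff.
algorithm_config = {"type": "LoRA", "adapter_dim": 16, "alpha": 16}
job_config: dict = {"hyperparameters": {}}

if algorithm_config.get("type") == "LoRA":
    # Everything except the discriminator becomes a LoRA hyperparameter,
    # so adapter_dim must already be present in algorithm_config.
    lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
    job_config["hyperparameters"]["lora"] = lora_config
else:
    raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")

print(job_config["hyperparameters"]["lora"])
# {'adapter_dim': 16, 'alpha': 16}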