rename impl, add config validation

Ubuntu 2025-03-12 07:41:03 +00:00, committed by raspawar
parent f5ebad130c
commit bd9a6d5f9c
2 changed files with 11 additions and 11 deletions


@@ -10,14 +10,18 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
 from .config import NvidiaPostTrainingConfig

-# post_training api and the torchtune provider is still experimental and under heavy development
 async def get_adapter_impl(
     config: NvidiaPostTrainingConfig,
     deps: Dict[Api, ProviderSpec],
 ):
-    from .post_training import NvidiaPostTrainingImpl
+    from .post_training import NvidiaPostTrainingAdapter
+
+    if not isinstance(config, NvidiaPostTrainingConfig):
+        raise RuntimeError(f"Unexpected config type: {type(config)}")

-    impl = NvidiaPostTrainingImpl(config)
+    impl = NvidiaPostTrainingAdapter(config)
     return impl
+
+__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
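
The isinstance guard above gives callers a clear failure when the provider is wired up with the wrong config object. A self-contained sketch of that fail-fast pattern (the class and function names mirror the diff, but the minimal bodies below are illustrative stand-ins, not the real provider code):

# Sketch of the fail-fast config check added above; names mirror the diff,
# bodies are illustrative only.
import asyncio
from dataclasses import dataclass


@dataclass
class NvidiaPostTrainingConfig:
    customizer_url: str = "http://localhost:8000"  # hypothetical field


class NvidiaPostTrainingAdapter:
    def __init__(self, config: NvidiaPostTrainingConfig):
        self.config = config


async def get_adapter_impl(config, deps):
    # Reject mis-wired configs up front instead of surfacing an opaque
    # AttributeError later, when the adapter first reads config fields.
    if not isinstance(config, NvidiaPostTrainingConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    return NvidiaPostTrainingAdapter(config)


async def main():
    impl = await get_adapter_impl(NvidiaPostTrainingConfig(), deps={})
    print(type(impl).__name__)  # NvidiaPostTrainingAdapter

    try:
        await get_adapter_impl({"customizer_url": "..."}, deps={})
    except RuntimeError as exc:
        print(exc)  # Unexpected config type: <class 'dict'>


asyncio.run(main())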


@@ -50,7 +50,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
     data: List[NvidiaPostTrainingJob]

-class NvidiaPostTrainingImpl:
+class NvidiaPostTrainingAdapter:
     def __init__(self, config: NvidiaPostTrainingConfig):
         self.config = config
         self.headers = {}
@@ -226,12 +226,8 @@ class NvidiaPostTrainingImpl:
             # Extract LoRA-specific parameters
             lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
             job_config["hyperparameters"]["lora"] = lora_config
-
-            # Add adapter_dim if available in training_config
-            if training_config.get("algorithm_config", {}).get("adapter_dim"):
-                job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"][
-                    "adapter_dim"
-                ]
+        else:
+            raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")

         # Create the customization job
         response = await self._make_request(
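
For reference, a minimal sketch of the branch this hunk rewrites: a dict-style algorithm_config tagged "LoRA" is copied (minus its "type" key) into job_config["hyperparameters"]["lora"], and any other algorithm config now raises instead of being patched with adapter_dim from training_config. The enclosing if/else and the helper name below are assumptions drawn from the visible diff context, not the full method.

# Illustrative sketch of the rewritten branch; the helper name and the if
# condition are assumptions based on the diff context.
def apply_algorithm_config(job_config: dict, algorithm_config) -> dict:
    if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
        # Extract LoRA-specific parameters (everything except the "type" tag)
        lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
        job_config["hyperparameters"]["lora"] = lora_config
    else:
        # Non-LoRA configs are now rejected explicitly rather than silently
        # back-filled with adapter_dim from training_config.
        raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
    return job_config


job = {"hyperparameters": {"epochs": 2}}
apply_algorithm_config(job, {"type": "LoRA", "adapter_dim": 16, "alpha": 16})
print(job["hyperparameters"]["lora"])  # {'adapter_dim': 16, 'alpha': 16}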