mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
rename impl, add config validation
This commit is contained in:
parent
f5ebad130c
commit
bd9a6d5f9c
2 changed files with 11 additions and 11 deletions
|
@ -10,14 +10,18 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec
|
|||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
# post_training api and the torchtune provider is still experimental and under heavy development
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NvidiaPostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
from .post_training import NvidiaPostTrainingImpl
|
||||
from .post_training import NvidiaPostTrainingAdapter
|
||||
|
||||
impl = NvidiaPostTrainingImpl(config)
|
||||
if not isinstance(config, NvidiaPostTrainingConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
impl = NvidiaPostTrainingAdapter(config)
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
|
||||
|
|
|
@ -50,7 +50,7 @@ class ListNvidiaPostTrainingJobs(BaseModel):
|
|||
data: List[NvidiaPostTrainingJob]
|
||||
|
||||
|
||||
class NvidiaPostTrainingImpl:
|
||||
class NvidiaPostTrainingAdapter:
|
||||
def __init__(self, config: NvidiaPostTrainingConfig):
|
||||
self.config = config
|
||||
self.headers = {}
|
||||
|
@ -226,12 +226,8 @@ class NvidiaPostTrainingImpl:
|
|||
# Extract LoRA-specific parameters
|
||||
lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
|
||||
job_config["hyperparameters"]["lora"] = lora_config
|
||||
|
||||
# Add adapter_dim if available in training_config
|
||||
if training_config.get("algorithm_config", {}).get("adapter_dim"):
|
||||
job_config["hyperparameters"]["lora"]["adapter_dim"] = training_config["algorithm_config"][
|
||||
"adapter_dim"
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
# Create the customization job
|
||||
response = await self._make_request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue