fix: remove extra sft args in NvidiaPostTrainingAdapter

The supervised_fine_tune method in NvidiaPostTrainingAdapter had some
extra args that aren't part of the post_training protocol, and these
extra args were causing FastAPI to throw an error when attempting to
stand up an endpoint that used this provider.

(Closes #1938)

Before this change, bringing up a stack with the `nvidia` template
failed. Afterwards, it passes. I'm testing this like:

```
INFERENCE_MODEL="meta/llama-3.1-8b-instruct" \
llama stack build --template nvidia --image-type venv --run
```

I also ensured the nvidia/test_supervised_fine_tuning.py tests still
pass via:

```
python -m pytest \
  tests/unit/providers/nvidia/test_supervised_fine_tuning.py
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-04-11 09:46:16 -04:00
parent 6aa459b00c
commit c2d23ddd75

View file

@ -206,10 +206,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
extra_json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
**kwargs,
) -> NvidiaPostTrainingJob:
"""
Fine-tunes a model on a dataset.