chore(misc): make tests and starter faster

This commit is contained in:
Ashwin Bharambe 2025-08-05 13:57:15 -07:00
parent ea46f74092
commit 2b4e88a3de
19 changed files with 2860 additions and 1660 deletions

View file

@ -5,8 +5,6 @@
# the root directory of this source tree.
from typing import Any
import pandas
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
@ -44,6 +42,8 @@ class PandasDataframeDataset:
if self.dataset_def.source.type == "uri":
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
import pandas
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
@ -103,6 +103,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
import pandas
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()

View file

@ -6,7 +6,8 @@
from typing import Any, Literal
from pydantic import BaseModel
from pydantic import BaseModel, Field
import tempfile
class HuggingFacePostTrainingConfig(BaseModel):
@ -71,7 +72,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str = "./checkpoints/dpo"
dpo_output_dir: str = Field(default_factory=lambda: tempfile.mkdtemp(prefix="dpo_output_"))
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:

View file

@ -22,15 +22,8 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
@ -85,6 +78,10 @@ class HuggingFacePostTrainingImpl:
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
@ -124,6 +121,10 @@ class HuggingFacePostTrainingImpl:
logger_config: dict[str, Any],
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
on_log_message_cb("Starting HF DPO alignment")
recipe = HFDPOAlignmentSingleDevice(
@ -168,7 +169,6 @@ class HuggingFacePostTrainingImpl:
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
@ -195,16 +195,13 @@ class HuggingFacePostTrainingImpl:
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]

View file

@ -23,9 +23,6 @@ from llama_stack.apis.post_training import (
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
@ -84,6 +81,10 @@ class TorchtunePostTrainingImpl:
if isinstance(algorithm_config, LoraFinetuningConfig):
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(

View file

@ -6,8 +6,6 @@
from typing import Any
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
@ -73,6 +71,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
import datasets as hf_datasets
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
@ -81,6 +81,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
import datasets as hf_datasets
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)

View file

@ -112,7 +112,8 @@ class OllamaInferenceAdapter(
@property
def openai_client(self) -> AsyncOpenAI:
if self._openai_client is None:
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
url = self.config.url.rstrip("/")
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
return self._openai_client
async def initialize(self) -> None:

View file

@ -9,12 +9,12 @@ import base64
import io
from urllib.parse import unquote
import pandas
from llama_stack.providers.utils.memory.vector_store import parse_data_url
async def get_dataframe_from_uri(uri: str):
import pandas
df = None
if uri.endswith(".csv"):
# Moving to its own thread to avoid io from blocking the eventloop