Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-12-23 05:52:26 +00:00
chore(misc): make tests and starter faster
parent ea46f74092
commit 2b4e88a3de

19 changed files with 2860 additions and 1660 deletions
PandasDataframeDataset / LocalFSDatasetIOImpl (localfs datasetio provider):

@@ -5,8 +5,6 @@
 # the root directory of this source tree.
 from typing import Any

-import pandas
-
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset

@@ -44,6 +42,8 @@ class PandasDataframeDataset:
         if self.dataset_def.source.type == "uri":
             self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
         elif self.dataset_def.source.type == "rows":
+            import pandas
+
             self.df = pandas.DataFrame(self.dataset_def.source.rows)
         else:
             raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")

@@ -103,6 +103,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         return paginate_records(records, start_index, limit)

     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        import pandas
+
         dataset_def = self.dataset_infos[dataset_id]
         dataset_impl = PandasDataframeDataset(dataset_def)
         await dataset_impl.load()
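The localfs changes move the pandas import from module scope into the two methods that actually use it, so importing the provider no longer pays pandas' startup cost. A minimal sketch of the pattern (standalone illustration in a fresh interpreter, not code from the commit):

    import sys

    def rows_to_dataframe(rows: list[dict]):
        # Deferred import: the cost is paid once, on the first call.
        # Later calls find pandas in sys.modules, so the "import"
        # amounts to a cached dict lookup.
        import pandas
        return pandas.DataFrame(rows)

    assert "pandas" not in sys.modules   # nothing loaded at import time
    df = rows_to_dataframe([{"a": 1}, {"a": 2}])
    assert "pandas" in sys.modules       # loaded lazily on first use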
HuggingFacePostTrainingConfig (post-training provider config):

@@ -6,7 +6,8 @@
 from typing import Any, Literal

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+import tempfile


 class HuggingFacePostTrainingConfig(BaseModel):

@@ -71,7 +72,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
     dpo_beta: float = 0.1
     use_reference_model: bool = True
     dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
-    dpo_output_dir: str = "./checkpoints/dpo"
+    dpo_output_dir: str = Field(default_factory=lambda: tempfile.mkdtemp(prefix="dpo_output_"))

     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
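The `dpo_output_dir` change swaps a hard-coded relative path for a fresh temp directory per config instance. `Field(default_factory=...)` is what makes the default per-instance; a quick sketch of the difference (illustrative, not from the commit):

    import tempfile
    from pydantic import BaseModel, Field

    class Cfg(BaseModel):
        # default_factory runs at instantiation time, so every Cfg() gets
        # its own directory. A plain default of tempfile.mkdtemp(...) would
        # execute once, at class-definition time, and every instance would
        # share (and write into) that single directory.
        out_dir: str = Field(default_factory=lambda: tempfile.mkdtemp(prefix="dpo_output_"))

    a, b = Cfg(), Cfg()
    assert a.out_dir != b.out_dir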
HuggingFacePostTrainingImpl (post-training provider):

@@ -22,15 +22,8 @@ from llama_stack.apis.post_training import (
 from llama_stack.providers.inline.post_training.huggingface.config import (
     HuggingFacePostTrainingConfig,
 )
-from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
-    HFFinetuningSingleDevice,
-)
-from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
-    HFDPOAlignmentSingleDevice,
-)
-
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod

 class TrainingArtifactType(Enum):

@@ -85,6 +78,10 @@ class HuggingFacePostTrainingImpl:
         algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+            from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
+                HFFinetuningSingleDevice,
+            )
+
             on_log_message_cb("Starting HF finetuning")

             recipe = HFFinetuningSingleDevice(

@@ -124,6 +121,10 @@ class HuggingFacePostTrainingImpl:
         logger_config: dict[str, Any],
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+            from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
+                HFDPOAlignmentSingleDevice,
+            )
+
             on_log_message_cb("Starting HF DPO alignment")

             recipe = HFDPOAlignmentSingleDevice(

@@ -168,7 +169,6 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
         job = self._scheduler.get_job(job_uuid)

@@ -195,16 +195,13 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None:
         self._scheduler.cancel(job_uuid)

     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
         job = self._scheduler.get_job(job_uuid)
         return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

     @webmethod(route="/post-training/jobs", method="GET")
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
             data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
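Moving the recipe imports inside `handler` means the torch-heavy training stack loads when a job actually starts rather than when the server boots. To check what a given module costs at startup, one generic way to measure it (not part of the commit):

    import importlib
    import sys
    import time

    def cold_import_seconds(name: str) -> float:
        # Only meaningful for a module that has not been imported yet;
        # a second import would just hit the sys.modules cache.
        assert name not in sys.modules, f"{name} already imported"
        t0 = time.perf_counter()
        importlib.import_module(name)
        return time.perf_counter() - t0

    # e.g. cold_import_seconds("torch") is typically measured in seconds
    # on a cold cache, which is exactly the cost you don't want at boot.
    print(f"pandas: {cold_import_seconds('pandas'):.2f}s")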
TorchtunePostTrainingImpl (post-training provider):

@@ -23,9 +23,6 @@ from llama_stack.apis.post_training import (
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
-from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
-    LoraFinetuningSingleDevice,
-)
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack.schema_utils import webmethod

@@ -84,6 +81,10 @@ class TorchtunePostTrainingImpl:
         if isinstance(algorithm_config, LoraFinetuningConfig):

             async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+                from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
+                    LoraFinetuningSingleDevice,
+                )
+
                 on_log_message_cb("Starting Lora finetuning")

                 recipe = LoraFinetuningSingleDevice(
HuggingfaceDatasetIOImpl (datasetio provider):

@@ -6,8 +6,6 @@
 from typing import Any
 from urllib.parse import parse_qs, urlparse

-import datasets as hf_datasets
-
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset

@@ -73,6 +71,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         start_index: int | None = None,
         limit: int | None = None,
     ) -> PaginatedResponse:
+        import datasets as hf_datasets
+
         dataset_def = self.dataset_infos[dataset_id]
         path, params = parse_hf_params(dataset_def)
         loaded_dataset = hf_datasets.load_dataset(path, **params)

@@ -81,6 +81,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         return paginate_records(records, start_index, limit)

     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        import datasets as hf_datasets
+
         dataset_def = self.dataset_infos[dataset_id]
         path, params = parse_hf_params(dataset_def)
         loaded_dataset = hf_datasets.load_dataset(path, **params)
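The `parse_hf_params` helper used above is not shown in this diff; judging from the `parse_qs`/`urlparse` imports, it splits a dataset URI into a `load_dataset` path plus keyword params. A hedged guess at that shape (the `hf://` scheme and the helper itself are assumptions, not from the commit):

    from urllib.parse import parse_qs, urlparse

    def parse_hf_uri(uri: str) -> tuple[str, dict]:
        # Hypothetical helper: turn an hf:// dataset URI into a
        # load_dataset path plus keyword params.
        parsed = urlparse(uri)
        path = (parsed.netloc + parsed.path).removeprefix("datasets/")
        params = {k: v[0] for k, v in parse_qs(parsed.query).items()}
        return path, params

    print(parse_hf_uri("hf://datasets/tatsu-lab/alpaca?split=train"))
    # -> ('tatsu-lab/alpaca', {'split': 'train'})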
OllamaInferenceAdapter (inference provider):

@@ -112,7 +112,8 @@ class OllamaInferenceAdapter(
     @property
     def openai_client(self) -> AsyncOpenAI:
         if self._openai_client is None:
-            self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
+            url = self.config.url.rstrip("/")
+            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
         return self._openai_client

     async def initialize(self) -> None:
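The Ollama change also normalizes the configured URL. The reason for `rstrip("/")`: a trailing slash in the user-supplied config would otherwise produce a malformed base URL:

    url = "http://localhost:11434/"          # user-supplied, trailing slash
    print(f"{url}/v1")                       # http://localhost:11434//v1  (malformed)
    print(f"{url.rstrip('/')}/v1")           # http://localhost:11434/v1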
get_dataframe_from_uri (pandas dataframe utils):

@@ -9,12 +9,12 @@ import base64
 import io
 from urllib.parse import unquote

-import pandas
-
 from llama_stack.providers.utils.memory.vector_store import parse_data_url


 async def get_dataframe_from_uri(uri: str):
+    import pandas
+
     df = None
     if uri.endswith(".csv"):
         # Moving to its own thread to avoid io from blocking the eventloop
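The comment in the hunk above states the design choice: pandas CSV parsing is blocking, so it is pushed off the event loop. The actual call is cut off in this diff, so the following is only a sketch consistent with that comment, using `asyncio.to_thread`:

    import asyncio

    async def read_csv_off_loop(path: str):
        import pandas
        # pandas.read_csv blocks on file IO and parsing; running it in a
        # worker thread keeps the event loop free to serve other requests.
        return await asyncio.to_thread(pandas.read_csv, path)

    # df = asyncio.run(read_csv_off_loop("data.csv"))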