chore: fix mypy violations in post_training modules

Note: this patch touches all files but post_training.py that will be
significantly changed by #1437, hence leaving it out of the picture for
now.

running_loss is now always Tensor (on-device) and doesn't change its
type from int to Tensor (which made mypy unhappy).

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-11 11:19:45 -04:00
parent 3b35a39b8b
commit 8c01246344
9 changed files with 56 additions and 69 deletions

View file

@ -9,6 +9,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA = {
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
@ -41,6 +44,9 @@ async def validate_input_dataset_schema(
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")