diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ae9ad5d4c..8021e0e55 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -11132,8 +11132,38 @@
"title": "Trace"
},
"Checkpoint": {
- "description": "Checkpoint created during training runs",
- "title": "Checkpoint"
+ "type": "object",
+ "properties": {
+ "identifier": {
+ "type": "string"
+ },
+ "created_at": {
+ "type": "string",
+ "format": "date-time"
+ },
+ "epoch": {
+ "type": "integer"
+ },
+ "post_training_job_id": {
+ "type": "string"
+ },
+ "path": {
+ "type": "string"
+ },
+ "training_metrics": {
+ "$ref": "#/components/schemas/PostTrainingMetric"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "identifier",
+ "created_at",
+ "epoch",
+ "post_training_job_id",
+ "path"
+ ],
+ "title": "Checkpoint",
+ "description": "Checkpoint created during training runs"
},
"PostTrainingJobArtifactsResponse": {
"type": "object",
@@ -11156,6 +11186,31 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
+ "PostTrainingMetric": {
+ "type": "object",
+ "properties": {
+ "epoch": {
+ "type": "integer"
+ },
+ "train_loss": {
+ "type": "number"
+ },
+ "validation_loss": {
+ "type": "number"
+ },
+ "perplexity": {
+ "type": "number"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "epoch",
+ "train_loss",
+ "validation_loss",
+ "perplexity"
+ ],
+ "title": "PostTrainingMetric"
+ },
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 48cefe12b..a18474646 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -7838,8 +7838,30 @@ components:
- start_time
title: Trace
Checkpoint:
- description: Checkpoint created during training runs
+ type: object
+ properties:
+ identifier:
+ type: string
+ created_at:
+ type: string
+ format: date-time
+ epoch:
+ type: integer
+ post_training_job_id:
+ type: string
+ path:
+ type: string
+ training_metrics:
+ $ref: '#/components/schemas/PostTrainingMetric'
+ additionalProperties: false
+ required:
+ - identifier
+ - created_at
+ - epoch
+ - post_training_job_id
+ - path
title: Checkpoint
+ description: Checkpoint created during training runs
PostTrainingJobArtifactsResponse:
type: object
properties:
@@ -7855,6 +7877,24 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
+ PostTrainingMetric:
+ type: object
+ properties:
+ epoch:
+ type: integer
+ train_loss:
+ type: number
+ validation_loss:
+ type: number
+ perplexity:
+ type: number
+ additionalProperties: false
+ required:
+ - epoch
+ - train_loss
+ - validation_loss
+ - perplexity
+ title: PostTrainingMetric
PostTrainingJobStatusResponse:
type: object
properties:
diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index 46cd101af..a2c3b78f1 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -19,8 +19,10 @@ class PostTrainingMetric(BaseModel):
perplexity: float
-@json_schema_type(schema={"description": "Checkpoint created during training runs"})
+@json_schema_type
class Checkpoint(BaseModel):
+ """Checkpoint created during training runs"""
+
identifier: str
created_at: datetime
epoch: int
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 1863b8a50..afdb33b62 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -61,25 +61,25 @@ class RunpodInferenceAdapter(
self,
model: str,
content: InterleavedContent,
- sampling_params: Optional[SamplingParams] = None,
- response_format: Optional[ResponseFormat] = None,
- stream: Optional[bool] = False,
- logprobs: Optional[LogProbConfig] = None,
+ sampling_params: SamplingParams | None = None,
+ response_format: ResponseFormat | None = None,
+ stream: bool | None = False,
+ logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
- messages: List[Message],
- sampling_params: Optional[SamplingParams] = None,
- response_format: Optional[ResponseFormat] = None,
- tools: Optional[List[ToolDefinition]] = None,
- tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = None,
- stream: Optional[bool] = False,
- logprobs: Optional[LogProbConfig] = None,
- tool_config: Optional[ToolConfig] = None,
+ messages: list[Message],
+ sampling_params: SamplingParams | None = None,
+ response_format: ResponseFormat | None = None,
+ tools: list[ToolDefinition] | None = None,
+ tool_choice: ToolChoice | None = ToolChoice.auto,
+ tool_prompt_format: ToolPromptFormat | None = None,
+ stream: bool | None = False,
+ logprobs: LogProbConfig | None = None,
+ tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@@ -129,10 +129,10 @@ class RunpodInferenceAdapter(
async def embeddings(
self,
model: str,
- contents: List[str] | List[InterleavedContentItem],
- text_truncation: Optional[TextTruncation] = TextTruncation.none,
- output_dimension: Optional[int] = None,
- task_type: Optional[EmbeddingTaskType] = None,
+ contents: list[str] | list[InterleavedContentItem],
+ text_truncation: TextTruncation | None = TextTruncation.none,
+ output_dimension: int | None = None,
+ task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/pyproject.toml b/pyproject.toml
index e800ed689..b41e03615 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -225,7 +225,6 @@ follow_imports = "silent"
# to exclude the entire directory.
exclude = [
# As we fix more and more of these, we should remove them from the list
- "^llama_stack/apis/common/training_types\\.py$",
"^llama_stack/cli/download\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/distribution/build\\.py$",