diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ae9ad5d4c..8021e0e55 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11132,8 +11132,38 @@ "title": "Trace" }, "Checkpoint": { - "description": "Checkpoint created during training runs", - "title": "Checkpoint" + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "created_at": { + "type": "string", + "format": "date-time" + }, + "epoch": { + "type": "integer" + }, + "post_training_job_id": { + "type": "string" + }, + "path": { + "type": "string" + }, + "training_metrics": { + "$ref": "#/components/schemas/PostTrainingMetric" + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "created_at", + "epoch", + "post_training_job_id", + "path" + ], + "title": "Checkpoint", + "description": "Checkpoint created during training runs" }, "PostTrainingJobArtifactsResponse": { "type": "object", @@ -11156,6 +11186,31 @@ "title": "PostTrainingJobArtifactsResponse", "description": "Artifacts of a finetuning job." }, + "PostTrainingMetric": { + "type": "object", + "properties": { + "epoch": { + "type": "integer" + }, + "train_loss": { + "type": "number" + }, + "validation_loss": { + "type": "number" + }, + "perplexity": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "epoch", + "train_loss", + "validation_loss", + "perplexity" + ], + "title": "PostTrainingMetric" + }, "PostTrainingJobStatusResponse": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 48cefe12b..a18474646 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -7838,8 +7838,30 @@ components: - start_time title: Trace Checkpoint: - description: Checkpoint created during training runs + type: object + properties: + identifier: + type: string + created_at: + type: string + format: date-time + epoch: + type: integer + post_training_job_id: + type: string + path: + type: string + training_metrics: + $ref: '#/components/schemas/PostTrainingMetric' + additionalProperties: false + required: + - identifier + - created_at + - epoch + - post_training_job_id + - path title: Checkpoint + description: Checkpoint created during training runs PostTrainingJobArtifactsResponse: type: object properties: @@ -7855,6 +7877,24 @@ components: - checkpoints title: PostTrainingJobArtifactsResponse description: Artifacts of a finetuning job. + PostTrainingMetric: + type: object + properties: + epoch: + type: integer + train_loss: + type: number + validation_loss: + type: number + perplexity: + type: number + additionalProperties: false + required: + - epoch + - train_loss + - validation_loss + - perplexity + title: PostTrainingMetric PostTrainingJobStatusResponse: type: object properties: diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py index 46cd101af..a2c3b78f1 100644 --- a/llama_stack/apis/common/training_types.py +++ b/llama_stack/apis/common/training_types.py @@ -19,8 +19,10 @@ class PostTrainingMetric(BaseModel): perplexity: float -@json_schema_type(schema={"description": "Checkpoint created during training runs"}) +@json_schema_type class Checkpoint(BaseModel): + """Checkpoint created during training runs""" + identifier: str created_at: datetime epoch: int diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 1863b8a50..afdb33b62 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -61,25 +61,25 @@ class RunpodInferenceAdapter( self, model: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: raise NotImplementedError() async def chat_completion( self, model: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -129,10 +129,10 @@ class RunpodInferenceAdapter( async def embeddings( self, model: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/pyproject.toml b/pyproject.toml index e800ed689..b41e03615 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,7 +225,6 @@ follow_imports = "silent" # to exclude the entire directory. exclude = [ # As we fix more and more of these, we should remove them from the list - "^llama_stack/apis/common/training_types\\.py$", "^llama_stack/cli/download\\.py$", "^llama_stack/cli/stack/_build\\.py$", "^llama_stack/distribution/build\\.py$",