From e5936a8df86869066f234d59f02b220d2215513d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 27 Jan 2025 09:18:13 -0800 Subject: [PATCH] Update discriminator to have the correct `mapping` (#881) See https://swagger.io/docs/specification/v3_0/data-models/inheritance-and-polymorphism/#discriminator When specifying discriminators, mapping must be specified unless the value of the discriminator is the subtype itself (which in our case is not.) The changes in the YAML are self-explanatory. --- .../strong_typing/classdef.py | 8 +- .../openapi_generator/strong_typing/schema.py | 11 + docs/resources/llama-stack-spec.html | 609 ++++++++++-------- docs/resources/llama-stack-spec.yaml | 333 ++++++---- llama_stack/apis/agents/agents.py | 16 +- llama_stack/apis/eval/eval.py | 18 +- llama_stack/apis/inference/inference.py | 3 + .../apis/post_training/post_training.py | 11 +- .../scoring_functions/scoring_functions.py | 19 +- llama_stack/apis/telemetry/telemetry.py | 34 +- 10 files changed, 642 insertions(+), 420 deletions(-) diff --git a/docs/openapi_generator/strong_typing/classdef.py b/docs/openapi_generator/strong_typing/classdef.py index 788ecc7e0..b86940420 100644 --- a/docs/openapi_generator/strong_typing/classdef.py +++ b/docs/openapi_generator/strong_typing/classdef.py @@ -122,10 +122,16 @@ class JsonSchemaAnyOf(JsonSchemaNode): anyOf: List["JsonSchemaAny"] +@dataclass +class Discriminator: + propertyName: str + mapping: Dict[str, str] + + @dataclass class JsonSchemaOneOf(JsonSchemaNode): oneOf: List["JsonSchemaAny"] - discriminator: Optional[str] + discriminator: Optional[Discriminator] JsonSchemaAny = Union[ diff --git a/docs/openapi_generator/strong_typing/schema.py b/docs/openapi_generator/strong_typing/schema.py index 5aa41b63f..826efdb4a 100644 --- a/docs/openapi_generator/strong_typing/schema.py +++ b/docs/openapi_generator/strong_typing/schema.py @@ -456,8 +456,19 @@ class JsonSchemaGenerator: ] } if discriminator: + # for each union type, we need to read the value of the discriminator + mapping = {} + for union_type in typing.get_args(typ): + props = self.type_to_schema(union_type, force_expand=True)[ + "properties" + ] + mapping[props[discriminator]["default"]] = self.type_to_schema( + union_type + )["$ref"] + ret["discriminator"] = { "propertyName": discriminator, + "mapping": mapping, } return ret elif origin_type is Literal: diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index f6024c586..7108ee9a5 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3812,7 +3812,11 @@ } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "image": "#/components/schemas/ImageContentItem", + "text": "#/components/schemas/TextContentItem" + } } }, "Message": { @@ -3831,7 +3835,13 @@ } ], "discriminator": { - "propertyName": "role" + "propertyName": "role", + "mapping": { + "user": "#/components/schemas/UserMessage", + "system": "#/components/schemas/SystemMessage", + "tool": "#/components/schemas/ToolResponseMessage", + "assistant": "#/components/schemas/CompletionMessage" + } } }, "SamplingParams": { @@ -3850,7 +3860,12 @@ } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "greedy": "#/components/schemas/GreedySamplingStrategy", + "top_p": "#/components/schemas/TopPSamplingStrategy", + "top_k": "#/components/schemas/TopKSamplingStrategy" + } } }, "max_tokens": { @@ -4313,91 +4328,101 @@ "job_uuid" ] }, + "GrammarResponseFormat": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "grammar", + "default": "grammar" + }, + "bnf": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "bnf" + ] + }, + "JsonSchemaResponseFormat": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "json_schema", + "default": "json_schema" + }, + "json_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "json_schema" + ] + }, "ResponseFormat": { "oneOf": [ { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json_schema", - "default": "json_schema" - }, - "json_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "json_schema" - ] + "$ref": "#/components/schemas/JsonSchemaResponseFormat" }, { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "grammar", - "default": "grammar" - }, - "bnf": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "bnf" - ] + "$ref": "#/components/schemas/GrammarResponseFormat" } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "json_schema": "#/components/schemas/JsonSchemaResponseFormat", + "grammar": "#/components/schemas/GrammarResponseFormat" + } } }, "ChatCompletionRequest": { @@ -4529,7 +4554,12 @@ } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "text": "#/components/schemas/TextDelta", + "image": "#/components/schemas/ImageDelta", + "tool_call": "#/components/schemas/ToolCallDelta" + } } }, "ImageDelta": { @@ -4737,8 +4767,7 @@ "default": "auto" }, "tool_prompt_format": { - "$ref": "#/components/schemas/ToolPromptFormat", - "default": "json" + "$ref": "#/components/schemas/ToolPromptFormat" }, "max_infer_iters": { "type": "integer", @@ -5018,33 +5047,42 @@ "type": "object", "properties": { "payload": { - "oneOf": [ - { - "$ref": "#/components/schemas/AgentTurnResponseStepStartPayload" - }, - { - "$ref": "#/components/schemas/AgentTurnResponseStepProgressPayload" - }, - { - "$ref": "#/components/schemas/AgentTurnResponseStepCompletePayload" - }, - { - "$ref": "#/components/schemas/AgentTurnResponseTurnStartPayload" - }, - { - "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" - } - ], - "discriminator": { - "propertyName": "event_type" - } + "$ref": "#/components/schemas/AgentTurnResponseEventPayload" } }, "additionalProperties": false, "required": [ "payload" + ] + }, + "AgentTurnResponseEventPayload": { + "oneOf": [ + { + "$ref": "#/components/schemas/AgentTurnResponseStepStartPayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseStepProgressPayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseStepCompletePayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseTurnStartPayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + } ], - "title": "Streamed agent execution response." + "discriminator": { + "propertyName": "event_type", + "mapping": { + "step_start": "#/components/schemas/AgentTurnResponseStepStartPayload", + "step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload", + "step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload", + "turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload", + "turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + } + } }, "AgentTurnResponseStepCompletePayload": { "type": "object", @@ -5082,7 +5120,13 @@ } ], "discriminator": { - "propertyName": "step_type" + "propertyName": "step_type", + "mapping": { + "inference": "#/components/schemas/InferenceStep", + "tool_execution": "#/components/schemas/ToolExecutionStep", + "shield_call": "#/components/schemas/ShieldCallStep", + "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" + } } } }, @@ -5485,7 +5529,13 @@ } ], "discriminator": { - "propertyName": "step_type" + "propertyName": "step_type", + "mapping": { + "inference": "#/components/schemas/InferenceStep", + "tool_execution": "#/components/schemas/ToolExecutionStep", + "shield_call": "#/components/schemas/ShieldCallStep", + "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" + } } } }, @@ -5629,35 +5679,12 @@ "default": "app" }, "eval_candidate": { - "oneOf": [ - { - "$ref": "#/components/schemas/ModelCandidate" - }, - { - "$ref": "#/components/schemas/AgentCandidate" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/EvalCandidate" }, "scoring_params": { "type": "object", "additionalProperties": { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/ScoringFnParams" } }, "num_examples": { @@ -5700,17 +5727,7 @@ "default": "benchmark" }, "eval_candidate": { - "oneOf": [ - { - "$ref": "#/components/schemas/ModelCandidate" - }, - { - "$ref": "#/components/schemas/AgentCandidate" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/EvalCandidate" }, "num_examples": { "type": "integer" @@ -5722,6 +5739,40 @@ "eval_candidate" ] }, + "EvalCandidate": { + "oneOf": [ + { + "$ref": "#/components/schemas/ModelCandidate" + }, + { + "$ref": "#/components/schemas/AgentCandidate" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "model": "#/components/schemas/ModelCandidate", + "agent": "#/components/schemas/AgentCandidate" + } + } + }, + "EvalTaskConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" + }, + { + "$ref": "#/components/schemas/AppEvalTaskConfig" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "benchmark": "#/components/schemas/BenchmarkEvalTaskConfig", + "app": "#/components/schemas/AppEvalTaskConfig" + } + } + }, "LLMAsJudgeScoringFnParams": { "type": "object", "properties": { @@ -5806,6 +5857,27 @@ "type" ] }, + "ScoringFnParams": { + "oneOf": [ + { + "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" + }, + { + "$ref": "#/components/schemas/RegexParserScoringFnParams" + }, + { + "$ref": "#/components/schemas/BasicScoringFnParams" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams", + "regex_parser": "#/components/schemas/RegexParserScoringFnParams", + "basic": "#/components/schemas/BasicScoringFnParams" + } + } + }, "EvaluateRowsRequest": { "type": "object", "properties": { @@ -5844,17 +5916,7 @@ } }, "task_config": { - "oneOf": [ - { - "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" - }, - { - "$ref": "#/components/schemas/AppEvalTaskConfig" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/EvalTaskConfig" } }, "additionalProperties": false, @@ -6019,7 +6081,13 @@ } ], "discriminator": { - "propertyName": "step_type" + "propertyName": "step_type", + "mapping": { + "inference": "#/components/schemas/InferenceStep", + "tool_execution": "#/components/schemas/ToolExecutionStep", + "shield_call": "#/components/schemas/ShieldCallStep", + "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" + } } } }, @@ -6237,7 +6305,19 @@ } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "string": "#/components/schemas/StringType", + "number": "#/components/schemas/NumberType", + "boolean": "#/components/schemas/BooleanType", + "array": "#/components/schemas/ArrayType", + "object": "#/components/schemas/ObjectType", + "json": "#/components/schemas/JsonType", + "union": "#/components/schemas/UnionType", + "chat_completion_input": "#/components/schemas/ChatCompletionInputType", + "completion_input": "#/components/schemas/CompletionInputType", + "agent_turn_input": "#/components/schemas/AgentTurnInputType" + } } }, "StringType": { @@ -6488,20 +6568,7 @@ "$ref": "#/components/schemas/ParamType" }, "params": { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/ScoringFnParams" } }, "additionalProperties": false, @@ -7415,6 +7482,27 @@ "data" ] }, + "Event": { + "oneOf": [ + { + "$ref": "#/components/schemas/UnstructuredLogEvent" + }, + { + "$ref": "#/components/schemas/MetricEvent" + }, + { + "$ref": "#/components/schemas/StructuredLogEvent" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "unstructured_log": "#/components/schemas/UnstructuredLogEvent", + "metric": "#/components/schemas/MetricEvent", + "structured_log": "#/components/schemas/StructuredLogEvent" + } + } + }, "LogSeverity": { "type": "string", "enum": [ @@ -7580,17 +7668,7 @@ "default": "structured_log" }, "payload": { - "oneOf": [ - { - "$ref": "#/components/schemas/SpanStartPayload" - }, - { - "$ref": "#/components/schemas/SpanEndPayload" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/StructuredLogPayload" } }, "additionalProperties": false, @@ -7602,6 +7680,23 @@ "payload" ] }, + "StructuredLogPayload": { + "oneOf": [ + { + "$ref": "#/components/schemas/SpanStartPayload" + }, + { + "$ref": "#/components/schemas/SpanEndPayload" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "span_start": "#/components/schemas/SpanStartPayload", + "span_end": "#/components/schemas/SpanEndPayload" + } + } + }, "UnstructuredLogEvent": { "type": "object", "properties": { @@ -7666,20 +7761,7 @@ "type": "object", "properties": { "event": { - "oneOf": [ - { - "$ref": "#/components/schemas/UnstructuredLogEvent" - }, - { - "$ref": "#/components/schemas/MetricEvent" - }, - { - "$ref": "#/components/schemas/StructuredLogEvent" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/Event" }, "ttl_seconds": { "type": "integer" @@ -8011,7 +8093,11 @@ } ], "discriminator": { - "propertyName": "type" + "propertyName": "type", + "mapping": { + "default": "#/components/schemas/DefaultRAGQueryGeneratorConfig", + "llm": "#/components/schemas/LLMRAGQueryGeneratorConfig" + } } }, "QueryRequest": { @@ -8394,20 +8480,7 @@ "type": "string" }, "params": { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/ScoringFnParams" } }, "additionalProperties": false, @@ -8533,17 +8606,7 @@ "type": "object", "properties": { "task_config": { - "oneOf": [ - { - "$ref": "#/components/schemas/BenchmarkEvalTaskConfig" - }, - { - "$ref": "#/components/schemas/AppEvalTaskConfig" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/EvalTaskConfig" } }, "additionalProperties": false, @@ -8682,20 +8745,7 @@ "additionalProperties": { "oneOf": [ { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/ScoringFnParams" }, { "type": "null" @@ -8736,20 +8786,7 @@ "additionalProperties": { "oneOf": [ { - "oneOf": [ - { - "$ref": "#/components/schemas/LLMAsJudgeScoringFnParams" - }, - { - "$ref": "#/components/schemas/RegexParserScoringFnParams" - }, - { - "$ref": "#/components/schemas/BasicScoringFnParams" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/ScoringFnParams" }, { "type": "null" @@ -8786,6 +8823,23 @@ "results" ] }, + "AlgorithmConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/LoraFinetuningConfig" + }, + { + "$ref": "#/components/schemas/QATFinetuningConfig" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "LoRA": "#/components/schemas/LoraFinetuningConfig", + "QAT": "#/components/schemas/QATFinetuningConfig" + } + } + }, "LoraFinetuningConfig": { "type": "object", "properties": { @@ -8919,17 +8973,7 @@ "type": "string" }, "algorithm_config": { - "oneOf": [ - { - "$ref": "#/components/schemas/LoraFinetuningConfig" - }, - { - "$ref": "#/components/schemas/QATFinetuningConfig" - } - ], - "discriminator": { - "propertyName": "type" - } + "$ref": "#/components/schemas/AlgorithmConfig" } }, "additionalProperties": false, @@ -9086,7 +9130,11 @@ }, { "name": "AgentTurnResponseEvent", - "description": "Streamed agent execution response.\n\n" + "description": "" + }, + { + "name": "AgentTurnResponseEventPayload", + "description": "" }, { "name": "AgentTurnResponseStepCompletePayload", @@ -9119,6 +9167,10 @@ "name": "AggregationFunctionType", "description": "" }, + { + "name": "AlgorithmConfig", + "description": "" + }, { "name": "AppEvalTaskConfig", "description": "" @@ -9275,10 +9327,18 @@ { "name": "Eval" }, + { + "name": "EvalCandidate", + "description": "" + }, { "name": "EvalTask", "description": "" }, + { + "name": "EvalTaskConfig", + "description": "" + }, { "name": "EvalTasks" }, @@ -9290,6 +9350,14 @@ "name": "EvaluateRowsRequest", "description": "" }, + { + "name": "Event", + "description": "" + }, + { + "name": "GrammarResponseFormat", + "description": "" + }, { "name": "GreedySamplingStrategy", "description": "" @@ -9344,6 +9412,10 @@ "name": "JobStatus", "description": "" }, + { + "name": "JsonSchemaResponseFormat", + "description": "" + }, { "name": "JsonType", "description": "" @@ -9628,6 +9700,10 @@ "name": "ScoringFn", "description": "" }, + { + "name": "ScoringFnParams", + "description": "" + }, { "name": "ScoringFunctions" }, @@ -9682,6 +9758,10 @@ "name": "StructuredLogEvent", "description": "" }, + { + "name": "StructuredLogPayload", + "description": "" + }, { "name": "SupervisedFineTuneRequest", "description": "" @@ -9878,6 +9958,7 @@ "AgentTool", "AgentTurnInputType", "AgentTurnResponseEvent", + "AgentTurnResponseEventPayload", "AgentTurnResponseStepCompletePayload", "AgentTurnResponseStepProgressPayload", "AgentTurnResponseStepStartPayload", @@ -9885,6 +9966,7 @@ "AgentTurnResponseTurnCompletePayload", "AgentTurnResponseTurnStartPayload", "AggregationFunctionType", + "AlgorithmConfig", "AppEvalTaskConfig", "AppendRowsRequest", "ArrayType", @@ -9921,9 +10003,13 @@ "EfficiencyConfig", "EmbeddingsRequest", "EmbeddingsResponse", + "EvalCandidate", "EvalTask", + "EvalTaskConfig", "EvaluateResponse", "EvaluateRowsRequest", + "Event", + "GrammarResponseFormat", "GreedySamplingStrategy", "HealthInfo", "ImageContentItem", @@ -9936,6 +10022,7 @@ "InvokeToolRequest", "Job", "JobStatus", + "JsonSchemaResponseFormat", "JsonType", "LLMAsJudgeScoringFnParams", "LLMRAGQueryGeneratorConfig", @@ -10004,6 +10091,7 @@ "ScoreRequest", "ScoreResponse", "ScoringFn", + "ScoringFnParams", "ScoringResult", "Session", "Shield", @@ -10016,6 +10104,7 @@ "StopReason", "StringType", "StructuredLogEvent", + "StructuredLogPayload", "SupervisedFineTuneRequest", "SyntheticDataGenerateRequest", "SyntheticDataGenerationResponse", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 21df2d96f..a7095716c 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -45,7 +45,6 @@ components: default: auto tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' - default: json toolgroups: items: $ref: '#/components/schemas/AgentTool' @@ -77,6 +76,11 @@ components: properties: step: discriminator: + mapping: + inference: '#/components/schemas/InferenceStep' + memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + shield_call: '#/components/schemas/ShieldCallStep' + tool_execution: '#/components/schemas/ToolExecutionStep' propertyName: step_type oneOf: - $ref: '#/components/schemas/InferenceStep' @@ -121,18 +125,25 @@ components: additionalProperties: false properties: payload: - discriminator: - propertyName: event_type - oneOf: - - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' - - $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload' - - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload' - - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload' - - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + $ref: '#/components/schemas/AgentTurnResponseEventPayload' required: - payload - title: Streamed agent execution response. type: object + AgentTurnResponseEventPayload: + discriminator: + mapping: + step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload' + step_progress: '#/components/schemas/AgentTurnResponseStepProgressPayload' + step_start: '#/components/schemas/AgentTurnResponseStepStartPayload' + turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload' + propertyName: event_type + oneOf: + - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' + - $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload' + - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload' AgentTurnResponseStepCompletePayload: additionalProperties: false properties: @@ -142,6 +153,11 @@ components: type: string step_details: discriminator: + mapping: + inference: '#/components/schemas/InferenceStep' + memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + shield_call: '#/components/schemas/ShieldCallStep' + tool_execution: '#/components/schemas/ToolExecutionStep' propertyName: step_type oneOf: - $ref: '#/components/schemas/InferenceStep' @@ -260,25 +276,25 @@ components: - categorical_count - accuracy type: string + AlgorithmConfig: + discriminator: + mapping: + LoRA: '#/components/schemas/LoraFinetuningConfig' + QAT: '#/components/schemas/QATFinetuningConfig' + propertyName: type + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QATFinetuningConfig' AppEvalTaskConfig: additionalProperties: false properties: eval_candidate: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/ModelCandidate' - - $ref: '#/components/schemas/AgentCandidate' + $ref: '#/components/schemas/EvalCandidate' num_examples: type: integer scoring_params: additionalProperties: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + $ref: '#/components/schemas/ScoringFnParams' type: object type: const: app @@ -412,11 +428,7 @@ components: additionalProperties: false properties: eval_candidate: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/ModelCandidate' - - $ref: '#/components/schemas/AgentCandidate' + $ref: '#/components/schemas/EvalCandidate' num_examples: type: integer type: @@ -632,6 +644,10 @@ components: type: object ContentDelta: discriminator: + mapping: + image: '#/components/schemas/ImageDelta' + text: '#/components/schemas/TextDelta' + tool_call: '#/components/schemas/ToolCallDelta' propertyName: type oneOf: - $ref: '#/components/schemas/TextDelta' @@ -830,6 +846,15 @@ components: required: - embeddings type: object + EvalCandidate: + discriminator: + mapping: + agent: '#/components/schemas/AgentCandidate' + model: '#/components/schemas/ModelCandidate' + propertyName: type + oneOf: + - $ref: '#/components/schemas/ModelCandidate' + - $ref: '#/components/schemas/AgentCandidate' EvalTask: additionalProperties: false properties: @@ -868,6 +893,15 @@ components: - scoring_functions - metadata type: object + EvalTaskConfig: + discriminator: + mapping: + app: '#/components/schemas/AppEvalTaskConfig' + benchmark: '#/components/schemas/BenchmarkEvalTaskConfig' + propertyName: type + oneOf: + - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' + - $ref: '#/components/schemas/AppEvalTaskConfig' EvaluateResponse: additionalProperties: false properties: @@ -911,16 +945,44 @@ components: type: string type: array task_config: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' - - $ref: '#/components/schemas/AppEvalTaskConfig' + $ref: '#/components/schemas/EvalTaskConfig' required: - input_rows - scoring_functions - task_config type: object + Event: + discriminator: + mapping: + metric: '#/components/schemas/MetricEvent' + structured_log: '#/components/schemas/StructuredLogEvent' + unstructured_log: '#/components/schemas/UnstructuredLogEvent' + propertyName: type + oneOf: + - $ref: '#/components/schemas/UnstructuredLogEvent' + - $ref: '#/components/schemas/MetricEvent' + - $ref: '#/components/schemas/StructuredLogEvent' + GrammarResponseFormat: + additionalProperties: false + properties: + bnf: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: grammar + default: grammar + type: string + required: + - type + - bnf + type: object GreedySamplingStrategy: additionalProperties: false properties: @@ -1055,6 +1117,9 @@ components: type: array InterleavedContentItem: discriminator: + mapping: + image: '#/components/schemas/ImageContentItem' + text: '#/components/schemas/TextContentItem' propertyName: type oneOf: - $ref: '#/components/schemas/ImageContentItem' @@ -1093,6 +1158,27 @@ components: - failed - scheduled type: string + JsonSchemaResponseFormat: + additionalProperties: false + properties: + json_schema: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: json_schema + default: json_schema + type: string + required: + - type + - json_schema + type: object JsonType: additionalProperties: false properties: @@ -1262,12 +1348,7 @@ components: additionalProperties: false properties: event: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/UnstructuredLogEvent' - - $ref: '#/components/schemas/MetricEvent' - - $ref: '#/components/schemas/StructuredLogEvent' + $ref: '#/components/schemas/Event' ttl_seconds: type: integer required: @@ -1346,6 +1427,11 @@ components: type: object Message: discriminator: + mapping: + assistant: '#/components/schemas/CompletionMessage' + system: '#/components/schemas/SystemMessage' + tool: '#/components/schemas/ToolResponseMessage' + user: '#/components/schemas/UserMessage' propertyName: role oneOf: - $ref: '#/components/schemas/UserMessage' @@ -1518,6 +1604,17 @@ components: type: object ParamType: discriminator: + mapping: + agent_turn_input: '#/components/schemas/AgentTurnInputType' + array: '#/components/schemas/ArrayType' + boolean: '#/components/schemas/BooleanType' + chat_completion_input: '#/components/schemas/ChatCompletionInputType' + completion_input: '#/components/schemas/CompletionInputType' + json: '#/components/schemas/JsonType' + number: '#/components/schemas/NumberType' + object: '#/components/schemas/ObjectType' + string: '#/components/schemas/StringType' + union: '#/components/schemas/UnionType' propertyName: type oneOf: - $ref: '#/components/schemas/StringType' @@ -1830,6 +1927,9 @@ components: type: object RAGQueryGeneratorConfig: discriminator: + mapping: + default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' + llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' propertyName: type oneOf: - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' @@ -1948,12 +2048,7 @@ components: description: type: string params: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + $ref: '#/components/schemas/ScoringFnParams' provider_id: type: string provider_scoring_fn_id: @@ -2031,48 +2126,13 @@ components: type: object ResponseFormat: discriminator: + mapping: + grammar: '#/components/schemas/GrammarResponseFormat' + json_schema: '#/components/schemas/JsonSchemaResponseFormat' propertyName: type oneOf: - - additionalProperties: false - properties: - json_schema: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: json_schema - default: json_schema - type: string - required: - - type - - json_schema - type: object - - additionalProperties: false - properties: - bnf: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: grammar - default: grammar - type: string - required: - - type - - bnf - type: object + - $ref: '#/components/schemas/JsonSchemaResponseFormat' + - $ref: '#/components/schemas/GrammarResponseFormat' RouteInfo: additionalProperties: false properties: @@ -2093,11 +2153,7 @@ components: additionalProperties: false properties: task_config: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/BenchmarkEvalTaskConfig' - - $ref: '#/components/schemas/AppEvalTaskConfig' + $ref: '#/components/schemas/EvalTaskConfig' required: - task_config type: object @@ -2163,6 +2219,10 @@ components: type: number strategy: discriminator: + mapping: + greedy: '#/components/schemas/GreedySamplingStrategy' + top_k: '#/components/schemas/TopKSamplingStrategy' + top_p: '#/components/schemas/TopPSamplingStrategy' propertyName: type oneOf: - $ref: '#/components/schemas/GreedySamplingStrategy' @@ -2201,12 +2261,7 @@ components: scoring_functions: additionalProperties: oneOf: - - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + - $ref: '#/components/schemas/ScoringFnParams' - type: 'null' type: object required: @@ -2244,12 +2299,7 @@ components: scoring_functions: additionalProperties: oneOf: - - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + - $ref: '#/components/schemas/ScoringFnParams' - type: 'null' type: object required: @@ -2284,12 +2334,7 @@ components: - type: object type: object params: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + $ref: '#/components/schemas/ScoringFnParams' provider_id: type: string provider_resource_id: @@ -2308,6 +2353,17 @@ components: - metadata - return_type type: object + ScoringFnParams: + discriminator: + mapping: + basic: '#/components/schemas/BasicScoringFnParams' + llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' + regex_parser: '#/components/schemas/RegexParserScoringFnParams' + propertyName: type + oneOf: + - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' + - $ref: '#/components/schemas/RegexParserScoringFnParams' + - $ref: '#/components/schemas/BasicScoringFnParams' ScoringResult: additionalProperties: false properties: @@ -2543,11 +2599,7 @@ components: - type: object type: object payload: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/SpanStartPayload' - - $ref: '#/components/schemas/SpanEndPayload' + $ref: '#/components/schemas/StructuredLogPayload' span_id: type: string timestamp: @@ -2566,15 +2618,20 @@ components: - type - payload type: object + StructuredLogPayload: + discriminator: + mapping: + span_end: '#/components/schemas/SpanEndPayload' + span_start: '#/components/schemas/SpanStartPayload' + propertyName: type + oneOf: + - $ref: '#/components/schemas/SpanStartPayload' + - $ref: '#/components/schemas/SpanEndPayload' SupervisedFineTuneRequest: additionalProperties: false properties: algorithm_config: - discriminator: - propertyName: type - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - $ref: '#/components/schemas/QATFinetuningConfig' + $ref: '#/components/schemas/AlgorithmConfig' checkpoint_dir: type: string hyperparam_search_config: @@ -3160,6 +3217,11 @@ components: steps: items: discriminator: + mapping: + inference: '#/components/schemas/InferenceStep' + memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + shield_call: '#/components/schemas/ShieldCallStep' + tool_execution: '#/components/schemas/ToolExecutionStep' propertyName: step_type oneOf: - $ref: '#/components/schemas/InferenceStep' @@ -5687,11 +5749,12 @@ tags: - description: name: AgentTurnInputType -- description: 'Streamed agent execution response. - - - ' +- description: name: AgentTurnResponseEvent +- description: + name: AgentTurnResponseEventPayload - description: name: AgentTurnResponseStepCompletePayload @@ -5717,6 +5780,9 @@ tags: - description: name: AggregationFunctionType +- description: + name: AlgorithmConfig - description: name: AppEvalTaskConfig @@ -5837,8 +5903,12 @@ tags: /> name: EmbeddingsResponse - name: Eval +- description: + name: EvalCandidate - description: name: EvalTask +- description: + name: EvalTaskConfig - name: EvalTasks - description: @@ -5846,6 +5916,11 @@ tags: - description: name: EvaluateRowsRequest +- description: + name: Event +- description: + name: GrammarResponseFormat - description: name: GreedySamplingStrategy @@ -5878,6 +5953,9 @@ tags: name: Job - description: name: JobStatus +- description: + name: JsonSchemaResponseFormat - description: name: JsonType - description: name: ScoringFn +- description: + name: ScoringFnParams - name: ScoringFunctions - description: name: ScoringResult @@ -6102,6 +6183,9 @@ tags: - description: name: StructuredLogEvent +- description: + name: StructuredLogPayload - description: name: SupervisedFineTuneRequest @@ -6239,6 +6323,7 @@ x-tagGroups: - AgentTool - AgentTurnInputType - AgentTurnResponseEvent + - AgentTurnResponseEventPayload - AgentTurnResponseStepCompletePayload - AgentTurnResponseStepProgressPayload - AgentTurnResponseStepStartPayload @@ -6246,6 +6331,7 @@ x-tagGroups: - AgentTurnResponseTurnCompletePayload - AgentTurnResponseTurnStartPayload - AggregationFunctionType + - AlgorithmConfig - AppEvalTaskConfig - AppendRowsRequest - ArrayType @@ -6282,9 +6368,13 @@ x-tagGroups: - EfficiencyConfig - EmbeddingsRequest - EmbeddingsResponse + - EvalCandidate - EvalTask + - EvalTaskConfig - EvaluateResponse - EvaluateRowsRequest + - Event + - GrammarResponseFormat - GreedySamplingStrategy - HealthInfo - ImageContentItem @@ -6297,6 +6387,7 @@ x-tagGroups: - InvokeToolRequest - Job - JobStatus + - JsonSchemaResponseFormat - JsonType - LLMAsJudgeScoringFnParams - LLMRAGQueryGeneratorConfig @@ -6365,6 +6456,7 @@ x-tagGroups: - ScoreRequest - ScoreResponse - ScoringFn + - ScoringFnParams - ScoringResult - Session - Shield @@ -6377,6 +6469,7 @@ x-tagGroups: - StopReason - StringType - StructuredLogEvent + - StructuredLogPayload - SupervisedFineTuneRequest - SyntheticDataGenerateRequest - SyntheticDataGenerationResponse diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 9b77ab8c7..f62d78390 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -229,11 +229,8 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): turn: Turn -@json_schema_type -class AgentTurnResponseEvent(BaseModel): - """Streamed agent execution response.""" - - payload: Annotated[ +AgentTurnResponseEventPayload = register_schema( + Annotated[ Union[ AgentTurnResponseStepStartPayload, AgentTurnResponseStepProgressPayload, @@ -242,7 +239,14 @@ class AgentTurnResponseEvent(BaseModel): AgentTurnResponseTurnCompletePayload, ], Field(discriminator="event_type"), - ] + ], + name="AgentTurnResponseEventPayload", +) + + +@json_schema_type +class AgentTurnResponseEvent(BaseModel): + payload: AgentTurnResponseEventPayload @json_schema_type diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index c9d2fb70b..dfeff0918 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -31,9 +31,10 @@ class AgentCandidate(BaseModel): config: AgentConfig -EvalCandidate = Annotated[ - Union[ModelCandidate, AgentCandidate], Field(discriminator="type") -] +EvalCandidate = register_schema( + Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")], + name="EvalCandidate", +) @json_schema_type @@ -61,9 +62,12 @@ class AppEvalTaskConfig(BaseModel): # we could optinally add any specific dataset config here -EvalTaskConfig = Annotated[ - Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type") -] +EvalTaskConfig = register_schema( + Annotated[ + Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type") + ], + name="EvalTaskConfig", +) @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index fdda5fe1b..871f1f633 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -157,11 +157,13 @@ class ChatCompletionResponseEvent(BaseModel): stop_reason: Optional[StopReason] = None +@json_schema_type class ResponseFormatType(Enum): json_schema = "json_schema" grammar = "grammar" +@json_schema_type class JsonSchemaResponseFormat(BaseModel): type: Literal[ResponseFormatType.json_schema.value] = ( ResponseFormatType.json_schema.value @@ -169,6 +171,7 @@ class JsonSchemaResponseFormat(BaseModel): json_schema: Dict[str, Any] +@json_schema_type class GrammarResponseFormat(BaseModel): type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value bnf: Dict[str, Any] diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index b9aa3bbde..675488ada 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -8,7 +8,7 @@ from datetime import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -88,9 +88,12 @@ class QATFinetuningConfig(BaseModel): group_size: int -AlgorithmConfig = Annotated[ - Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type") -] +AlgorithmConfig = register_schema( + Annotated[ + Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type") + ], + name="AlgorithmConfig", +) @json_schema_type diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 3089dc0a4..b2e85f855 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -16,7 +16,7 @@ from typing import ( Union, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -82,14 +82,17 @@ class BasicScoringFnParams(BaseModel): ) -ScoringFnParams = Annotated[ - Union[ - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, - BasicScoringFnParams, +ScoringFnParams = register_schema( + Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + BasicScoringFnParams, + ], + Field(discriminator="type"), ], - Field(discriminator="type"), -] + name="ScoringFnParams", +) class CommonScoringFnFields(BaseModel): diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 30a4e2342..284e3a970 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -17,7 +17,7 @@ from typing import ( Union, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -115,13 +115,16 @@ class SpanEndPayload(BaseModel): status: SpanStatus -StructuredLogPayload = Annotated[ - Union[ - SpanStartPayload, - SpanEndPayload, +StructuredLogPayload = register_schema( + Annotated[ + Union[ + SpanStartPayload, + SpanEndPayload, + ], + Field(discriminator="type"), ], - Field(discriminator="type"), -] + name="StructuredLogPayload", +) @json_schema_type @@ -130,14 +133,17 @@ class StructuredLogEvent(EventCommon): payload: StructuredLogPayload -Event = Annotated[ - Union[ - UnstructuredLogEvent, - MetricEvent, - StructuredLogEvent, +Event = register_schema( + Annotated[ + Union[ + UnstructuredLogEvent, + MetricEvent, + StructuredLogEvent, + ], + Field(discriminator="type"), ], - Field(discriminator="type"), -] + name="Event", +) @json_schema_type