diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5cc910a55..75f0dddd1 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -189,13 +189,11 @@ class AgentToolGroupWithArgs(BaseModel): args: Dict[str, Any] -AgentToolGroup = register_schema( - Union[ - str, - AgentToolGroupWithArgs, - ], - name="AgentTool", -) +AgentToolGroup = Union[ + str, + AgentToolGroupWithArgs, +] +register_schema(AgentToolGroup, name="AgentTool") class AgentConfigCommon(BaseModel): @@ -312,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): turn: Turn -AgentTurnResponseEventPayload = register_schema( - Annotated[ - Union[ - AgentTurnResponseStepStartPayload, - AgentTurnResponseStepProgressPayload, - AgentTurnResponseStepCompletePayload, - AgentTurnResponseTurnStartPayload, - AgentTurnResponseTurnCompletePayload, - AgentTurnResponseTurnAwaitingInputPayload, - ], - Field(discriminator="event_type"), +AgentTurnResponseEventPayload = Annotated[ + Union[ + AgentTurnResponseStepStartPayload, + AgentTurnResponseStepProgressPayload, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseTurnStartPayload, + AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], - name="AgentTurnResponseEventPayload", -) + Field(discriminator="event_type"), +] +register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload") @json_schema_type diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 0d0afa894..9d4e21308 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -63,19 +63,15 @@ class TextContentItem(BaseModel): # other modalities can be added here -InterleavedContentItem = register_schema( - Annotated[ - Union[ImageContentItem, TextContentItem], - Field(discriminator="type"), - ], - name="InterleavedContentItem", -) +InterleavedContentItem = Annotated[ + Union[ImageContentItem, TextContentItem], + Field(discriminator="type"), +] +register_schema(InterleavedContentItem, name="InterleavedContentItem") # accept a single "str" as a special case since it is common -InterleavedContent = register_schema( - Union[str, InterleavedContentItem, List[InterleavedContentItem]], - name="InterleavedContent", -) +InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]] +register_schema(InterleavedContent, name="InterleavedContent") @json_schema_type @@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel): # streaming completions send a stream of ContentDeltas -ContentDelta = register_schema( - Annotated[ - Union[TextDelta, ImageDelta, ToolCallDelta], - Field(discriminator="type"), - ], - name="ContentDelta", -) +ContentDelta = Annotated[ + Union[TextDelta, ImageDelta, ToolCallDelta], + Field(discriminator="type"), +] +register_schema(ContentDelta, name="ContentDelta") diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index d7746df8d..5d9f000be 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -72,24 +72,22 @@ class DialogType(BaseModel): type: Literal["dialog"] = "dialog" -ParamType = register_schema( - Annotated[ - Union[ - StringType, - NumberType, - BooleanType, - ArrayType, - ObjectType, - JsonType, - UnionType, - ChatCompletionInputType, - CompletionInputType, - AgentTurnInputType, - ], - Field(discriminator="type"), +ParamType = Annotated[ + Union[ + StringType, + NumberType, + BooleanType, + ArrayType, + ObjectType, + JsonType, + UnionType, + ChatCompletionInputType, + CompletionInputType, + AgentTurnInputType, ], - name="ParamType", -) + Field(discriminator="type"), +] +register_schema(ParamType, name="ParamType") """ # TODO: recursive definition of ParamType in these containers diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e2c940f64..32ccde144 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -84,13 +84,11 @@ class RowsDataSource(BaseModel): rows: List[Dict[str, Any]] -DataSource = register_schema( - Annotated[ - Union[URIDataSource, RowsDataSource], - Field(discriminator="type"), - ], - name="DataSource", -) +DataSource = Annotated[ + Union[URIDataSource, RowsDataSource], + Field(discriminator="type"), +] +register_schema(DataSource, name="DataSource") class CommonDatasetFields(BaseModel): diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 51c38b16a..d05786321 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -43,10 +43,8 @@ class AgentCandidate(BaseModel): config: AgentConfig -EvalCandidate = register_schema( - Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")], - name="EvalCandidate", -) +EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")] +register_schema(EvalCandidate, name="EvalCandidate") @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 0a4324cdf..7d3539dcb 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -144,18 +144,16 @@ class CompletionMessage(BaseModel): tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) -Message = register_schema( - Annotated[ - Union[ - UserMessage, - SystemMessage, - ToolResponseMessage, - CompletionMessage, - ], - Field(discriminator="role"), +Message = Annotated[ + Union[ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, ], - name="Message", -) + Field(discriminator="role"), +] +register_schema(Message, name="Message") @json_schema_type @@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel): bnf: Dict[str, Any] -ResponseFormat = register_schema( - Annotated[ - Union[JsonSchemaResponseFormat, GrammarResponseFormat], - Field(discriminator="type"), - ], - name="ResponseFormat", -) +ResponseFormat = Annotated[ + Union[JsonSchemaResponseFormat, GrammarResponseFormat], + Field(discriminator="type"), +] +register_schema(ResponseFormat, name="ResponseFormat") # This is an internally used class diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 362f87a26..e61c0e4e4 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel): group_size: int -AlgorithmConfig = register_schema( - Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], - name="AlgorithmConfig", -) +AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")] +register_schema(AlgorithmConfig, name="AlgorithmConfig") @json_schema_type diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 57761c940..4f85947dd 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -79,17 +79,15 @@ class BasicScoringFnParams(BaseModel): ) -ScoringFnParams = register_schema( - Annotated[ - Union[ - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, - BasicScoringFnParams, - ], - Field(discriminator="type"), +ScoringFnParams = Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + BasicScoringFnParams, ], - name="ScoringFnParams", -) + Field(discriminator="type"), +] +register_schema(ScoringFnParams, name="ScoringFnParams") class CommonScoringFnFields(BaseModel): diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index cbea57e79..d57c311b2 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel): status: SpanStatus -StructuredLogPayload = register_schema( - Annotated[ - Union[ - SpanStartPayload, - SpanEndPayload, - ], - Field(discriminator="type"), +StructuredLogPayload = Annotated[ + Union[ + SpanStartPayload, + SpanEndPayload, ], - name="StructuredLogPayload", -) + Field(discriminator="type"), +] +register_schema(StructuredLogPayload, name="StructuredLogPayload") @json_schema_type @@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon): payload: StructuredLogPayload -Event = register_schema( - Annotated[ - Union[ - UnstructuredLogEvent, - MetricEvent, - StructuredLogEvent, - ], - Field(discriminator="type"), +Event = Annotated[ + Union[ + UnstructuredLogEvent, + MetricEvent, + StructuredLogEvent, ], - name="Event", -) + Field(discriminator="type"), +] +register_schema(Event, name="Event") @json_schema_type diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 671e19619..73b36e050 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel): template: str -RAGQueryGeneratorConfig = register_schema( - Annotated[ - Union[ - DefaultRAGQueryGeneratorConfig, - LLMRAGQueryGeneratorConfig, - ], - Field(discriminator="type"), +RAGQueryGeneratorConfig = Annotated[ + Union[ + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, ], - name="RAGQueryGeneratorConfig", -) + Field(discriminator="type"), +] +register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") @json_schema_type diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index 9842d7980..f762eb50f 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel): top_k: int = Field(..., ge=1) -SamplingStrategy = register_schema( - Annotated[ - Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], - Field(discriminator="type"), - ], - name="SamplingStrategy", -) +SamplingStrategy = Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), +] +register_schema(SamplingStrategy, name="SamplingStrategy") @json_schema_type