From 41bd3505399b9b909270539eaecf063e0215eff1 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Wed, 19 Mar 2025 23:29:00 -0400 Subject: [PATCH] chore: Don't set type variables from register_schema() (#1713) # What does this PR do? Don't set type variables from register_schema(). `mypy` is not happy about it since type variables are calculated at runtime and hence the typing hints are not available during static analysis. Good news is there is no good reason to set the variables from the return type. Signed-off-by: Ihar Hrachyshka Signed-off-by: Ihar Hrachyshka --- llama_stack/apis/agents/agents.py | 36 +++++++++---------- llama_stack/apis/common/content_types.py | 30 +++++++--------- llama_stack/apis/common/type_system.py | 32 ++++++++--------- llama_stack/apis/datasets/datasets.py | 12 +++---- llama_stack/apis/eval/eval.py | 6 ++-- llama_stack/apis/inference/inference.py | 32 ++++++++--------- .../apis/post_training/post_training.py | 6 ++-- .../scoring_functions/scoring_functions.py | 18 +++++----- llama_stack/apis/telemetry/telemetry.py | 34 ++++++++---------- llama_stack/apis/tools/rag_tool.py | 16 ++++----- llama_stack/models/llama/datatypes.py | 12 +++---- 11 files changed, 101 insertions(+), 133 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5cc910a55..75f0dddd1 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -189,13 +189,11 @@ class AgentToolGroupWithArgs(BaseModel): args: Dict[str, Any] -AgentToolGroup = register_schema( - Union[ - str, - AgentToolGroupWithArgs, - ], - name="AgentTool", -) +AgentToolGroup = Union[ + str, + AgentToolGroupWithArgs, +] +register_schema(AgentToolGroup, name="AgentTool") class AgentConfigCommon(BaseModel): @@ -312,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): turn: Turn -AgentTurnResponseEventPayload = register_schema( - Annotated[ - Union[ - AgentTurnResponseStepStartPayload, - AgentTurnResponseStepProgressPayload, - AgentTurnResponseStepCompletePayload, - AgentTurnResponseTurnStartPayload, - AgentTurnResponseTurnCompletePayload, - AgentTurnResponseTurnAwaitingInputPayload, - ], - Field(discriminator="event_type"), +AgentTurnResponseEventPayload = Annotated[ + Union[ + AgentTurnResponseStepStartPayload, + AgentTurnResponseStepProgressPayload, + AgentTurnResponseStepCompletePayload, + AgentTurnResponseTurnStartPayload, + AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], - name="AgentTurnResponseEventPayload", -) + Field(discriminator="event_type"), +] +register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload") @json_schema_type diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 0d0afa894..9d4e21308 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -63,19 +63,15 @@ class TextContentItem(BaseModel): # other modalities can be added here -InterleavedContentItem = register_schema( - Annotated[ - Union[ImageContentItem, TextContentItem], - Field(discriminator="type"), - ], - name="InterleavedContentItem", -) +InterleavedContentItem = Annotated[ + Union[ImageContentItem, TextContentItem], + Field(discriminator="type"), +] +register_schema(InterleavedContentItem, name="InterleavedContentItem") # accept a single "str" as a special case since it is common -InterleavedContent = register_schema( - Union[str, InterleavedContentItem, List[InterleavedContentItem]], - name="InterleavedContent", -) +InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]] +register_schema(InterleavedContent, name="InterleavedContent") @json_schema_type @@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel): # streaming completions send a stream of ContentDeltas -ContentDelta = register_schema( - Annotated[ - Union[TextDelta, ImageDelta, ToolCallDelta], - Field(discriminator="type"), - ], - name="ContentDelta", -) +ContentDelta = Annotated[ + Union[TextDelta, ImageDelta, ToolCallDelta], + Field(discriminator="type"), +] +register_schema(ContentDelta, name="ContentDelta") diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index d7746df8d..5d9f000be 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -72,24 +72,22 @@ class DialogType(BaseModel): type: Literal["dialog"] = "dialog" -ParamType = register_schema( - Annotated[ - Union[ - StringType, - NumberType, - BooleanType, - ArrayType, - ObjectType, - JsonType, - UnionType, - ChatCompletionInputType, - CompletionInputType, - AgentTurnInputType, - ], - Field(discriminator="type"), +ParamType = Annotated[ + Union[ + StringType, + NumberType, + BooleanType, + ArrayType, + ObjectType, + JsonType, + UnionType, + ChatCompletionInputType, + CompletionInputType, + AgentTurnInputType, ], - name="ParamType", -) + Field(discriminator="type"), +] +register_schema(ParamType, name="ParamType") """ # TODO: recursive definition of ParamType in these containers diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e2c940f64..32ccde144 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -84,13 +84,11 @@ class RowsDataSource(BaseModel): rows: List[Dict[str, Any]] -DataSource = register_schema( - Annotated[ - Union[URIDataSource, RowsDataSource], - Field(discriminator="type"), - ], - name="DataSource", -) +DataSource = Annotated[ + Union[URIDataSource, RowsDataSource], + Field(discriminator="type"), +] +register_schema(DataSource, name="DataSource") class CommonDatasetFields(BaseModel): diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 51c38b16a..d05786321 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -43,10 +43,8 @@ class AgentCandidate(BaseModel): config: AgentConfig -EvalCandidate = register_schema( - Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")], - name="EvalCandidate", -) +EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")] +register_schema(EvalCandidate, name="EvalCandidate") @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 0a4324cdf..7d3539dcb 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -144,18 +144,16 @@ class CompletionMessage(BaseModel): tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) -Message = register_schema( - Annotated[ - Union[ - UserMessage, - SystemMessage, - ToolResponseMessage, - CompletionMessage, - ], - Field(discriminator="role"), +Message = Annotated[ + Union[ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, ], - name="Message", -) + Field(discriminator="role"), +] +register_schema(Message, name="Message") @json_schema_type @@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel): bnf: Dict[str, Any] -ResponseFormat = register_schema( - Annotated[ - Union[JsonSchemaResponseFormat, GrammarResponseFormat], - Field(discriminator="type"), - ], - name="ResponseFormat", -) +ResponseFormat = Annotated[ + Union[JsonSchemaResponseFormat, GrammarResponseFormat], + Field(discriminator="type"), +] +register_schema(ResponseFormat, name="ResponseFormat") # This is an internally used class diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 362f87a26..e61c0e4e4 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel): group_size: int -AlgorithmConfig = register_schema( - Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], - name="AlgorithmConfig", -) +AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")] +register_schema(AlgorithmConfig, name="AlgorithmConfig") @json_schema_type diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 57761c940..4f85947dd 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -79,17 +79,15 @@ class BasicScoringFnParams(BaseModel): ) -ScoringFnParams = register_schema( - Annotated[ - Union[ - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, - BasicScoringFnParams, - ], - Field(discriminator="type"), +ScoringFnParams = Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + BasicScoringFnParams, ], - name="ScoringFnParams", -) + Field(discriminator="type"), +] +register_schema(ScoringFnParams, name="ScoringFnParams") class CommonScoringFnFields(BaseModel): diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index cbea57e79..d57c311b2 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel): status: SpanStatus -StructuredLogPayload = register_schema( - Annotated[ - Union[ - SpanStartPayload, - SpanEndPayload, - ], - Field(discriminator="type"), +StructuredLogPayload = Annotated[ + Union[ + SpanStartPayload, + SpanEndPayload, ], - name="StructuredLogPayload", -) + Field(discriminator="type"), +] +register_schema(StructuredLogPayload, name="StructuredLogPayload") @json_schema_type @@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon): payload: StructuredLogPayload -Event = register_schema( - Annotated[ - Union[ - UnstructuredLogEvent, - MetricEvent, - StructuredLogEvent, - ], - Field(discriminator="type"), +Event = Annotated[ + Union[ + UnstructuredLogEvent, + MetricEvent, + StructuredLogEvent, ], - name="Event", -) + Field(discriminator="type"), +] +register_schema(Event, name="Event") @json_schema_type diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 671e19619..73b36e050 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel): template: str -RAGQueryGeneratorConfig = register_schema( - Annotated[ - Union[ - DefaultRAGQueryGeneratorConfig, - LLMRAGQueryGeneratorConfig, - ], - Field(discriminator="type"), +RAGQueryGeneratorConfig = Annotated[ + Union[ + DefaultRAGQueryGeneratorConfig, + LLMRAGQueryGeneratorConfig, ], - name="RAGQueryGeneratorConfig", -) + Field(discriminator="type"), +] +register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") @json_schema_type diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index 9842d7980..f762eb50f 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel): top_k: int = Field(..., ge=1) -SamplingStrategy = register_schema( - Annotated[ - Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], - Field(discriminator="type"), - ], - name="SamplingStrategy", -) +SamplingStrategy = Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), +] +register_schema(SamplingStrategy, name="SamplingStrategy") @json_schema_type