forked from phoenix-oss/llama-stack-mirror
chore: Don't set type variables from register_schema() (#1713)
# What does this PR do? Don't set type variables from register_schema(). `mypy` is not happy about it since type variables are calculated at runtime and hence the typing hints are not available during static analysis. Good news is there is no good reason to set the variables from the return type. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com> Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
a483a58c6e
commit
41bd350539
11 changed files with 101 additions and 133 deletions
|
@ -189,13 +189,11 @@ class AgentToolGroupWithArgs(BaseModel):
|
||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
AgentToolGroup = register_schema(
|
AgentToolGroup = Union[
|
||||||
Union[
|
str,
|
||||||
str,
|
AgentToolGroupWithArgs,
|
||||||
AgentToolGroupWithArgs,
|
]
|
||||||
],
|
register_schema(AgentToolGroup, name="AgentTool")
|
||||||
name="AgentTool",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigCommon(BaseModel):
|
class AgentConfigCommon(BaseModel):
|
||||||
|
@ -312,20 +310,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||||
turn: Turn
|
turn: Turn
|
||||||
|
|
||||||
|
|
||||||
AgentTurnResponseEventPayload = register_schema(
|
AgentTurnResponseEventPayload = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
AgentTurnResponseStepStartPayload,
|
||||||
AgentTurnResponseStepStartPayload,
|
AgentTurnResponseStepProgressPayload,
|
||||||
AgentTurnResponseStepProgressPayload,
|
AgentTurnResponseStepCompletePayload,
|
||||||
AgentTurnResponseStepCompletePayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnAwaitingInputPayload,
|
||||||
AgentTurnResponseTurnAwaitingInputPayload,
|
|
||||||
],
|
|
||||||
Field(discriminator="event_type"),
|
|
||||||
],
|
],
|
||||||
name="AgentTurnResponseEventPayload",
|
Field(discriminator="event_type"),
|
||||||
)
|
]
|
||||||
|
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# other modalities can be added here
|
# other modalities can be added here
|
||||||
InterleavedContentItem = register_schema(
|
InterleavedContentItem = Annotated[
|
||||||
Annotated[
|
Union[ImageContentItem, TextContentItem],
|
||||||
Union[ImageContentItem, TextContentItem],
|
Field(discriminator="type"),
|
||||||
Field(discriminator="type"),
|
]
|
||||||
],
|
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||||
name="InterleavedContentItem",
|
|
||||||
)
|
|
||||||
|
|
||||||
# accept a single "str" as a special case since it is common
|
# accept a single "str" as a special case since it is common
|
||||||
InterleavedContent = register_schema(
|
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
||||||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
register_schema(InterleavedContent, name="InterleavedContent")
|
||||||
name="InterleavedContent",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# streaming completions send a stream of ContentDeltas
|
# streaming completions send a stream of ContentDeltas
|
||||||
ContentDelta = register_schema(
|
ContentDelta = Annotated[
|
||||||
Annotated[
|
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
Field(discriminator="type"),
|
||||||
Field(discriminator="type"),
|
]
|
||||||
],
|
register_schema(ContentDelta, name="ContentDelta")
|
||||||
name="ContentDelta",
|
|
||||||
)
|
|
||||||
|
|
|
@ -72,24 +72,22 @@ class DialogType(BaseModel):
|
||||||
type: Literal["dialog"] = "dialog"
|
type: Literal["dialog"] = "dialog"
|
||||||
|
|
||||||
|
|
||||||
ParamType = register_schema(
|
ParamType = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
StringType,
|
||||||
StringType,
|
NumberType,
|
||||||
NumberType,
|
BooleanType,
|
||||||
BooleanType,
|
ArrayType,
|
||||||
ArrayType,
|
ObjectType,
|
||||||
ObjectType,
|
JsonType,
|
||||||
JsonType,
|
UnionType,
|
||||||
UnionType,
|
ChatCompletionInputType,
|
||||||
ChatCompletionInputType,
|
CompletionInputType,
|
||||||
CompletionInputType,
|
AgentTurnInputType,
|
||||||
AgentTurnInputType,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
],
|
],
|
||||||
name="ParamType",
|
Field(discriminator="type"),
|
||||||
)
|
]
|
||||||
|
register_schema(ParamType, name="ParamType")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# TODO: recursive definition of ParamType in these containers
|
# TODO: recursive definition of ParamType in these containers
|
||||||
|
|
|
@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
|
||||||
rows: List[Dict[str, Any]]
|
rows: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
DataSource = register_schema(
|
DataSource = Annotated[
|
||||||
Annotated[
|
Union[URIDataSource, RowsDataSource],
|
||||||
Union[URIDataSource, RowsDataSource],
|
Field(discriminator="type"),
|
||||||
Field(discriminator="type"),
|
]
|
||||||
],
|
register_schema(DataSource, name="DataSource")
|
||||||
name="DataSource",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CommonDatasetFields(BaseModel):
|
class CommonDatasetFields(BaseModel):
|
||||||
|
|
|
@ -43,10 +43,8 @@ class AgentCandidate(BaseModel):
|
||||||
config: AgentConfig
|
config: AgentConfig
|
||||||
|
|
||||||
|
|
||||||
EvalCandidate = register_schema(
|
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
|
||||||
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
|
register_schema(EvalCandidate, name="EvalCandidate")
|
||||||
name="EvalCandidate",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -144,18 +144,16 @@ class CompletionMessage(BaseModel):
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
Message = register_schema(
|
Message = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
UserMessage,
|
||||||
UserMessage,
|
SystemMessage,
|
||||||
SystemMessage,
|
ToolResponseMessage,
|
||||||
ToolResponseMessage,
|
CompletionMessage,
|
||||||
CompletionMessage,
|
|
||||||
],
|
|
||||||
Field(discriminator="role"),
|
|
||||||
],
|
],
|
||||||
name="Message",
|
Field(discriminator="role"),
|
||||||
)
|
]
|
||||||
|
register_schema(Message, name="Message")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
|
||||||
bnf: Dict[str, Any]
|
bnf: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
ResponseFormat = register_schema(
|
ResponseFormat = Annotated[
|
||||||
Annotated[
|
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
Field(discriminator="type"),
|
||||||
Field(discriminator="type"),
|
]
|
||||||
],
|
register_schema(ResponseFormat, name="ResponseFormat")
|
||||||
name="ResponseFormat",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# This is an internally used class
|
# This is an internally used class
|
||||||
|
|
|
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
|
||||||
group_size: int
|
group_size: int
|
||||||
|
|
||||||
|
|
||||||
AlgorithmConfig = register_schema(
|
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||||
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
|
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||||
name="AlgorithmConfig",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -79,17 +79,15 @@ class BasicScoringFnParams(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ScoringFnParams = register_schema(
|
ScoringFnParams = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
LLMAsJudgeScoringFnParams,
|
||||||
LLMAsJudgeScoringFnParams,
|
RegexParserScoringFnParams,
|
||||||
RegexParserScoringFnParams,
|
BasicScoringFnParams,
|
||||||
BasicScoringFnParams,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
],
|
],
|
||||||
name="ScoringFnParams",
|
Field(discriminator="type"),
|
||||||
)
|
]
|
||||||
|
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||||
|
|
||||||
|
|
||||||
class CommonScoringFnFields(BaseModel):
|
class CommonScoringFnFields(BaseModel):
|
||||||
|
|
|
@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
|
||||||
status: SpanStatus
|
status: SpanStatus
|
||||||
|
|
||||||
|
|
||||||
StructuredLogPayload = register_schema(
|
StructuredLogPayload = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
SpanStartPayload,
|
||||||
SpanStartPayload,
|
SpanEndPayload,
|
||||||
SpanEndPayload,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
],
|
],
|
||||||
name="StructuredLogPayload",
|
Field(discriminator="type"),
|
||||||
)
|
]
|
||||||
|
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
|
||||||
payload: StructuredLogPayload
|
payload: StructuredLogPayload
|
||||||
|
|
||||||
|
|
||||||
Event = register_schema(
|
Event = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
UnstructuredLogEvent,
|
||||||
UnstructuredLogEvent,
|
MetricEvent,
|
||||||
MetricEvent,
|
StructuredLogEvent,
|
||||||
StructuredLogEvent,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
],
|
],
|
||||||
name="Event",
|
Field(discriminator="type"),
|
||||||
)
|
]
|
||||||
|
register_schema(Event, name="Event")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -58,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
||||||
template: str
|
template: str
|
||||||
|
|
||||||
|
|
||||||
RAGQueryGeneratorConfig = register_schema(
|
RAGQueryGeneratorConfig = Annotated[
|
||||||
Annotated[
|
Union[
|
||||||
Union[
|
DefaultRAGQueryGeneratorConfig,
|
||||||
DefaultRAGQueryGeneratorConfig,
|
LLMRAGQueryGeneratorConfig,
|
||||||
LLMRAGQueryGeneratorConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
],
|
],
|
||||||
name="RAGQueryGeneratorConfig",
|
Field(discriminator="type"),
|
||||||
)
|
]
|
||||||
|
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -186,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
|
||||||
top_k: int = Field(..., ge=1)
|
top_k: int = Field(..., ge=1)
|
||||||
|
|
||||||
|
|
||||||
SamplingStrategy = register_schema(
|
SamplingStrategy = Annotated[
|
||||||
Annotated[
|
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
Field(discriminator="type"),
|
||||||
Field(discriminator="type"),
|
]
|
||||||
],
|
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||||
name="SamplingStrategy",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue