forked from phoenix-oss/llama-stack-mirror
Update discriminator to have the correct mapping
(#881)
See https://swagger.io/docs/specification/v3_0/data-models/inheritance-and-polymorphism/#discriminator When specifying discriminators, mapping must be specified unless the value of the discriminator is the subtype itself (which in our case is not.) The changes in the YAML are self-explanatory.
This commit is contained in:
parent
a6d20e0f53
commit
e5936a8df8
10 changed files with 642 additions and 420 deletions
|
@ -229,11 +229,8 @@ class AgentTurnResponseTurnCompletePayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""Streamed agent execution response."""
|
||||
|
||||
payload: Annotated[
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
|
@ -242,7 +239,14 @@ class AgentTurnResponseEvent(BaseModel):
|
|||
AgentTurnResponseTurnCompletePayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
],
|
||||
name="AgentTurnResponseEventPayload",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
payload: AgentTurnResponseEventPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
@ -31,9 +31,10 @@ class AgentCandidate(BaseModel):
|
|||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = Annotated[
|
||||
Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
|
||||
]
|
||||
EvalCandidate = register_schema(
|
||||
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
|
||||
name="EvalCandidate",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -61,9 +62,12 @@ class AppEvalTaskConfig(BaseModel):
|
|||
# we could optinally add any specific dataset config here
|
||||
|
||||
|
||||
EvalTaskConfig = Annotated[
|
||||
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
|
||||
]
|
||||
EvalTaskConfig = register_schema(
|
||||
Annotated[
|
||||
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
|
||||
],
|
||||
name="EvalTaskConfig",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -157,11 +157,13 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
stop_reason: Optional[StopReason] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ResponseFormatType(Enum):
|
||||
json_schema = "json_schema"
|
||||
grammar = "grammar"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JsonSchemaResponseFormat(BaseModel):
|
||||
type: Literal[ResponseFormatType.json_schema.value] = (
|
||||
ResponseFormatType.json_schema.value
|
||||
|
@ -169,6 +171,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
|||
json_schema: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GrammarResponseFormat(BaseModel):
|
||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||
bnf: Dict[str, Any]
|
||||
|
|
|
@ -8,7 +8,7 @@ from datetime import datetime
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
@ -88,9 +88,12 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = Annotated[
|
||||
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
|
||||
]
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[
|
||||
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
|
||||
],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -16,7 +16,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
@ -82,14 +82,17 @@ class BasicScoringFnParams(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
ScoringFnParams = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
name="ScoringFnParams",
|
||||
)
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
@ -115,13 +115,16 @@ class SpanEndPayload(BaseModel):
|
|||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
StructuredLogPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
name="StructuredLogPayload",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -130,14 +133,17 @@ class StructuredLogEvent(EventCommon):
|
|||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
Event = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
name="Event",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue