Update OpenAPI generator to output discriminator (#848)

oneOf should have discriminators so Stainless can generate better code

## Test Plan

Going to generate the SDK now and check.
This commit is contained in:
Ashwin Bharambe 2025-01-22 22:15:23 -08:00 committed by GitHub
parent 65f07c3d63
commit 35c71d5bbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 159 additions and 35 deletions

View file

@ -2088,7 +2088,7 @@
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n", "from termcolor import cprint\n",
"from llama_stack_client.types.tool_runtime import DocumentParam as Document\n", "from llama_stack_client.types import Document\n",
"\n", "\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n", "urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n", "documents = [\n",

View file

@ -125,6 +125,7 @@ class JsonSchemaAnyOf(JsonSchemaNode):
@dataclass @dataclass
class JsonSchemaOneOf(JsonSchemaNode): class JsonSchemaOneOf(JsonSchemaNode):
oneOf: List["JsonSchemaAny"] oneOf: List["JsonSchemaAny"]
discriminator: Optional[str]
JsonSchemaAny = Union[ JsonSchemaAny = Union[

View file

@ -36,6 +36,7 @@ from typing import (
) )
import jsonschema import jsonschema
from typing_extensions import Annotated
from . import docstring from . import docstring
from .auxiliary import ( from .auxiliary import (
@ -329,7 +330,6 @@ class JsonSchemaGenerator:
if metadata is not None: if metadata is not None:
# type is Annotated[T, ...] # type is Annotated[T, ...]
typ = typing.get_args(data_type)[0] typ = typing.get_args(data_type)[0]
schema = self._simple_type_to_schema(typ) schema = self._simple_type_to_schema(typ)
if schema is not None: if schema is not None:
# recognize well-known auxiliary types # recognize well-known auxiliary types
@ -446,12 +446,20 @@ class JsonSchemaGenerator:
], ],
} }
elif origin_type is Union: elif origin_type is Union:
return { discriminator = None
if typing.get_origin(data_type) is Annotated:
discriminator = typing.get_args(data_type)[1].discriminator
ret = {
"oneOf": [ "oneOf": [
self.type_to_schema(union_type) self.type_to_schema(union_type)
for union_type in typing.get_args(typ) for union_type in typing.get_args(typ)
] ]
} }
if discriminator:
ret["discriminator"] = {
"propertyName": discriminator,
}
return ret
elif origin_type is Literal: elif origin_type is Literal:
(literal_value,) = typing.get_args(typ) # unpack value of literal type (literal_value,) = typing.get_args(typ) # unpack value of literal type
schema = self.type_to_schema(type(literal_value)) schema = self.type_to_schema(type(literal_value))

View file

@ -3810,7 +3810,10 @@
{ {
"$ref": "#/components/schemas/TextContentItem" "$ref": "#/components/schemas/TextContentItem"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"Message": { "Message": {
"oneOf": [ "oneOf": [
@ -3826,7 +3829,10 @@
{ {
"$ref": "#/components/schemas/CompletionMessage" "$ref": "#/components/schemas/CompletionMessage"
} }
] ],
"discriminator": {
"propertyName": "role"
}
}, },
"SamplingParams": { "SamplingParams": {
"type": "object", "type": "object",
@ -3842,7 +3848,10 @@
{ {
"$ref": "#/components/schemas/TopKSamplingStrategy" "$ref": "#/components/schemas/TopKSamplingStrategy"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"max_tokens": { "max_tokens": {
"type": "integer", "type": "integer",
@ -4386,7 +4395,10 @@
"bnf" "bnf"
] ]
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"ChatCompletionRequest": { "ChatCompletionRequest": {
"type": "object", "type": "object",
@ -4515,7 +4527,10 @@
{ {
"$ref": "#/components/schemas/ToolCallDelta" "$ref": "#/components/schemas/ToolCallDelta"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"ImageDelta": { "ImageDelta": {
"type": "object", "type": "object",
@ -5019,7 +5034,10 @@
{ {
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
} }
] ],
"discriminator": {
"propertyName": "event_type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -5062,7 +5080,10 @@
{ {
"$ref": "#/components/schemas/MemoryRetrievalStep" "$ref": "#/components/schemas/MemoryRetrievalStep"
} }
] ],
"discriminator": {
"propertyName": "step_type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -5462,7 +5483,10 @@
{ {
"$ref": "#/components/schemas/MemoryRetrievalStep" "$ref": "#/components/schemas/MemoryRetrievalStep"
} }
] ],
"discriminator": {
"propertyName": "step_type"
}
} }
}, },
"output_message": { "output_message": {
@ -5612,7 +5636,10 @@
{ {
"$ref": "#/components/schemas/AgentCandidate" "$ref": "#/components/schemas/AgentCandidate"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"scoring_params": { "scoring_params": {
"type": "object", "type": "object",
@ -5627,7 +5654,10 @@
{ {
"$ref": "#/components/schemas/BasicScoringFnParams" "$ref": "#/components/schemas/BasicScoringFnParams"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"num_examples": { "num_examples": {
@ -5677,7 +5707,10 @@
{ {
"$ref": "#/components/schemas/AgentCandidate" "$ref": "#/components/schemas/AgentCandidate"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"num_examples": { "num_examples": {
"type": "integer" "type": "integer"
@ -5818,7 +5851,10 @@
{ {
"$ref": "#/components/schemas/AppEvalTaskConfig" "$ref": "#/components/schemas/AppEvalTaskConfig"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -5981,7 +6017,10 @@
{ {
"$ref": "#/components/schemas/MemoryRetrievalStep" "$ref": "#/components/schemas/MemoryRetrievalStep"
} }
] ],
"discriminator": {
"propertyName": "step_type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -6196,7 +6235,10 @@
{ {
"$ref": "#/components/schemas/AgentTurnInputType" "$ref": "#/components/schemas/AgentTurnInputType"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"StringType": { "StringType": {
"type": "object", "type": "object",
@ -6456,7 +6498,10 @@
{ {
"$ref": "#/components/schemas/BasicScoringFnParams" "$ref": "#/components/schemas/BasicScoringFnParams"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -7542,7 +7587,10 @@
{ {
"$ref": "#/components/schemas/SpanEndPayload" "$ref": "#/components/schemas/SpanEndPayload"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -7628,7 +7676,10 @@
{ {
"$ref": "#/components/schemas/StructuredLogEvent" "$ref": "#/components/schemas/StructuredLogEvent"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"ttl_seconds": { "ttl_seconds": {
"type": "integer" "type": "integer"
@ -7958,7 +8009,10 @@
{ {
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig" "$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
"QueryRequest": { "QueryRequest": {
"type": "object", "type": "object",
@ -8350,7 +8404,10 @@
{ {
"$ref": "#/components/schemas/BasicScoringFnParams" "$ref": "#/components/schemas/BasicScoringFnParams"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -8483,7 +8540,10 @@
{ {
"$ref": "#/components/schemas/AppEvalTaskConfig" "$ref": "#/components/schemas/AppEvalTaskConfig"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -8632,7 +8692,10 @@
{ {
"$ref": "#/components/schemas/BasicScoringFnParams" "$ref": "#/components/schemas/BasicScoringFnParams"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
{ {
"type": "null" "type": "null"
@ -8683,7 +8746,10 @@
{ {
"$ref": "#/components/schemas/BasicScoringFnParams" "$ref": "#/components/schemas/BasicScoringFnParams"
} }
] ],
"discriminator": {
"propertyName": "type"
}
}, },
{ {
"type": "null" "type": "null"
@ -8860,7 +8926,10 @@
{ {
"$ref": "#/components/schemas/QATFinetuningConfig" "$ref": "#/components/schemas/QATFinetuningConfig"
} }
] ],
"discriminator": {
"propertyName": "type"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,

View file

@ -76,6 +76,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
step: step:
discriminator:
propertyName: step_type
oneOf: oneOf:
- $ref: '#/components/schemas/InferenceStep' - $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep' - $ref: '#/components/schemas/ToolExecutionStep'
@ -119,6 +121,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
payload: payload:
discriminator:
propertyName: event_type
oneOf: oneOf:
- $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload'
- $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload' - $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload'
@ -137,6 +141,8 @@ components:
default: step_complete default: step_complete
type: string type: string
step_details: step_details:
discriminator:
propertyName: step_type
oneOf: oneOf:
- $ref: '#/components/schemas/InferenceStep' - $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep' - $ref: '#/components/schemas/ToolExecutionStep'
@ -258,6 +264,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
eval_candidate: eval_candidate:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/ModelCandidate' - $ref: '#/components/schemas/ModelCandidate'
- $ref: '#/components/schemas/AgentCandidate' - $ref: '#/components/schemas/AgentCandidate'
@ -265,6 +273,8 @@ components:
type: integer type: integer
scoring_params: scoring_params:
additionalProperties: additionalProperties:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -402,6 +412,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
eval_candidate: eval_candidate:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/ModelCandidate' - $ref: '#/components/schemas/ModelCandidate'
- $ref: '#/components/schemas/AgentCandidate' - $ref: '#/components/schemas/AgentCandidate'
@ -619,6 +631,8 @@ components:
title: streamed completion response. title: streamed completion response.
type: object type: object
ContentDelta: ContentDelta:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/TextDelta' - $ref: '#/components/schemas/TextDelta'
- $ref: '#/components/schemas/ImageDelta' - $ref: '#/components/schemas/ImageDelta'
@ -897,6 +911,8 @@ components:
type: string type: string
type: array type: array
task_config: task_config:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/BenchmarkEvalTaskConfig' - $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
- $ref: '#/components/schemas/AppEvalTaskConfig' - $ref: '#/components/schemas/AppEvalTaskConfig'
@ -1038,6 +1054,8 @@ components:
$ref: '#/components/schemas/InterleavedContentItem' $ref: '#/components/schemas/InterleavedContentItem'
type: array type: array
InterleavedContentItem: InterleavedContentItem:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/ImageContentItem' - $ref: '#/components/schemas/ImageContentItem'
- $ref: '#/components/schemas/TextContentItem' - $ref: '#/components/schemas/TextContentItem'
@ -1244,6 +1262,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
event: event:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/UnstructuredLogEvent' - $ref: '#/components/schemas/UnstructuredLogEvent'
- $ref: '#/components/schemas/MetricEvent' - $ref: '#/components/schemas/MetricEvent'
@ -1325,6 +1345,8 @@ components:
- inserted_context - inserted_context
type: object type: object
Message: Message:
discriminator:
propertyName: role
oneOf: oneOf:
- $ref: '#/components/schemas/UserMessage' - $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage' - $ref: '#/components/schemas/SystemMessage'
@ -1495,6 +1517,8 @@ components:
- total_count - total_count
type: object type: object
ParamType: ParamType:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/StringType' - $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType' - $ref: '#/components/schemas/NumberType'
@ -1805,6 +1829,8 @@ components:
- max_chunks - max_chunks
type: object type: object
RAGQueryGeneratorConfig: RAGQueryGeneratorConfig:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig' - $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig' - $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
@ -1922,6 +1948,8 @@ components:
description: description:
type: string type: string
params: params:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -2002,6 +2030,8 @@ components:
- embedding_model - embedding_model
type: object type: object
ResponseFormat: ResponseFormat:
discriminator:
propertyName: type
oneOf: oneOf:
- additionalProperties: false - additionalProperties: false
properties: properties:
@ -2063,6 +2093,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
task_config: task_config:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/BenchmarkEvalTaskConfig' - $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
- $ref: '#/components/schemas/AppEvalTaskConfig' - $ref: '#/components/schemas/AppEvalTaskConfig'
@ -2130,6 +2162,8 @@ components:
default: 1.0 default: 1.0
type: number type: number
strategy: strategy:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/GreedySamplingStrategy' - $ref: '#/components/schemas/GreedySamplingStrategy'
- $ref: '#/components/schemas/TopPSamplingStrategy' - $ref: '#/components/schemas/TopPSamplingStrategy'
@ -2167,7 +2201,9 @@ components:
scoring_functions: scoring_functions:
additionalProperties: additionalProperties:
oneOf: oneOf:
- oneOf: - discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams' - $ref: '#/components/schemas/BasicScoringFnParams'
@ -2208,7 +2244,9 @@ components:
scoring_functions: scoring_functions:
additionalProperties: additionalProperties:
oneOf: oneOf:
- oneOf: - discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams' - $ref: '#/components/schemas/BasicScoringFnParams'
@ -2246,6 +2284,8 @@ components:
- type: object - type: object
type: object type: object
params: params:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -2503,6 +2543,8 @@ components:
- type: object - type: object
type: object type: object
payload: payload:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/SpanStartPayload' - $ref: '#/components/schemas/SpanStartPayload'
- $ref: '#/components/schemas/SpanEndPayload' - $ref: '#/components/schemas/SpanEndPayload'
@ -2528,6 +2570,8 @@ components:
additionalProperties: false additionalProperties: false
properties: properties:
algorithm_config: algorithm_config:
discriminator:
propertyName: type
oneOf: oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig' - $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig' - $ref: '#/components/schemas/QATFinetuningConfig'
@ -3115,6 +3159,8 @@ components:
type: string type: string
steps: steps:
items: items:
discriminator:
propertyName: step_type
oneOf: oneOf:
- $ref: '#/components/schemas/InferenceStep' - $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep' - $ref: '#/components/schemas/ToolExecutionStep'

View file

@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v6" KEY_VERSION = "v7"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -8,7 +8,7 @@ import random
import pytest import pytest
from llama_stack_client.types.tool_runtime import DocumentParam from llama_stack_client.types import Document
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@ -38,22 +38,22 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def sample_documents(): def sample_documents():
return [ return [
DocumentParam( Document(
document_id="test-doc-1", document_id="test-doc-1",
content="Python is a high-level programming language.", content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"}, metadata={"category": "programming", "difficulty": "beginner"},
), ),
DocumentParam( Document(
document_id="test-doc-2", document_id="test-doc-2",
content="Machine learning is a subset of artificial intelligence.", content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"}, metadata={"category": "AI", "difficulty": "advanced"},
), ),
DocumentParam( Document(
document_id="test-doc-3", document_id="test-doc-3",
content="Data structures are fundamental to computer science.", content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"}, metadata={"category": "computer science", "difficulty": "intermediate"},
), ),
DocumentParam( Document(
document_id="test-doc-4", document_id="test-doc-4",
content="Neural networks are inspired by biological neural networks.", content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"}, metadata={"category": "AI", "difficulty": "advanced"},
@ -148,7 +148,7 @@ def test_vector_db_insert_from_url_and_query(
"llama3.rst", "llama3.rst",
] ]
documents = [ documents = [
DocumentParam( Document(
document_id=f"num-{i}", document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain", mime_type="text/plain",