Update OpenAPI generator to output discriminator (#848)

oneOf should have discriminators so Stainless can generate better code

## Test Plan

Going to generate the SDK now and check.
This commit is contained in:
Ashwin Bharambe 2025-01-22 22:15:23 -08:00 committed by GitHub
parent 65f07c3d63
commit 35c71d5bbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 159 additions and 35 deletions

View file

@ -2088,7 +2088,7 @@
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"from termcolor import cprint\n",
"from llama_stack_client.types.tool_runtime import DocumentParam as Document\n",
"from llama_stack_client.types import Document\n",
"\n",
"urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
"documents = [\n",

View file

@ -125,6 +125,7 @@ class JsonSchemaAnyOf(JsonSchemaNode):
@dataclass
class JsonSchemaOneOf(JsonSchemaNode):
oneOf: List["JsonSchemaAny"]
discriminator: Optional[str]
JsonSchemaAny = Union[

View file

@ -36,6 +36,7 @@ from typing import (
)
import jsonschema
from typing_extensions import Annotated
from . import docstring
from .auxiliary import (
@ -329,7 +330,6 @@ class JsonSchemaGenerator:
if metadata is not None:
# type is Annotated[T, ...]
typ = typing.get_args(data_type)[0]
schema = self._simple_type_to_schema(typ)
if schema is not None:
# recognize well-known auxiliary types
@ -446,12 +446,20 @@ class JsonSchemaGenerator:
],
}
elif origin_type is Union:
return {
discriminator = None
if typing.get_origin(data_type) is Annotated:
discriminator = typing.get_args(data_type)[1].discriminator
ret = {
"oneOf": [
self.type_to_schema(union_type)
for union_type in typing.get_args(typ)
]
}
if discriminator:
ret["discriminator"] = {
"propertyName": discriminator,
}
return ret
elif origin_type is Literal:
(literal_value,) = typing.get_args(typ) # unpack value of literal type
schema = self.type_to_schema(type(literal_value))

View file

@ -3810,7 +3810,10 @@
{
"$ref": "#/components/schemas/TextContentItem"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"Message": {
"oneOf": [
@ -3826,7 +3829,10 @@
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
],
"discriminator": {
"propertyName": "role"
}
},
"SamplingParams": {
"type": "object",
@ -3842,7 +3848,10 @@
{
"$ref": "#/components/schemas/TopKSamplingStrategy"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"max_tokens": {
"type": "integer",
@ -4386,7 +4395,10 @@
"bnf"
]
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"ChatCompletionRequest": {
"type": "object",
@ -4515,7 +4527,10 @@
{
"$ref": "#/components/schemas/ToolCallDelta"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"ImageDelta": {
"type": "object",
@ -5019,7 +5034,10 @@
{
"$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload"
}
]
],
"discriminator": {
"propertyName": "event_type"
}
}
},
"additionalProperties": false,
@ -5062,7 +5080,10 @@
{
"$ref": "#/components/schemas/MemoryRetrievalStep"
}
]
],
"discriminator": {
"propertyName": "step_type"
}
}
},
"additionalProperties": false,
@ -5462,7 +5483,10 @@
{
"$ref": "#/components/schemas/MemoryRetrievalStep"
}
]
],
"discriminator": {
"propertyName": "step_type"
}
}
},
"output_message": {
@ -5612,7 +5636,10 @@
{
"$ref": "#/components/schemas/AgentCandidate"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"scoring_params": {
"type": "object",
@ -5627,7 +5654,10 @@
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"num_examples": {
@ -5677,7 +5707,10 @@
{
"$ref": "#/components/schemas/AgentCandidate"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"num_examples": {
"type": "integer"
@ -5818,7 +5851,10 @@
{
"$ref": "#/components/schemas/AppEvalTaskConfig"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,
@ -5981,7 +6017,10 @@
{
"$ref": "#/components/schemas/MemoryRetrievalStep"
}
]
],
"discriminator": {
"propertyName": "step_type"
}
}
},
"additionalProperties": false,
@ -6196,7 +6235,10 @@
{
"$ref": "#/components/schemas/AgentTurnInputType"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"StringType": {
"type": "object",
@ -6456,7 +6498,10 @@
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,
@ -7542,7 +7587,10 @@
{
"$ref": "#/components/schemas/SpanEndPayload"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,
@ -7628,7 +7676,10 @@
{
"$ref": "#/components/schemas/StructuredLogEvent"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"ttl_seconds": {
"type": "integer"
@ -7958,7 +8009,10 @@
{
"$ref": "#/components/schemas/LLMRAGQueryGeneratorConfig"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
"QueryRequest": {
"type": "object",
@ -8350,7 +8404,10 @@
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,
@ -8483,7 +8540,10 @@
{
"$ref": "#/components/schemas/AppEvalTaskConfig"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,
@ -8632,7 +8692,10 @@
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
{
"type": "null"
@ -8683,7 +8746,10 @@
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
],
"discriminator": {
"propertyName": "type"
}
},
{
"type": "null"
@ -8860,7 +8926,10 @@
{
"$ref": "#/components/schemas/QATFinetuningConfig"
}
]
],
"discriminator": {
"propertyName": "type"
}
}
},
"additionalProperties": false,

View file

@ -76,6 +76,8 @@ components:
additionalProperties: false
properties:
step:
discriminator:
propertyName: step_type
oneOf:
- $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep'
@ -119,6 +121,8 @@ components:
additionalProperties: false
properties:
payload:
discriminator:
propertyName: event_type
oneOf:
- $ref: '#/components/schemas/AgentTurnResponseStepStartPayload'
- $ref: '#/components/schemas/AgentTurnResponseStepProgressPayload'
@ -137,6 +141,8 @@ components:
default: step_complete
type: string
step_details:
discriminator:
propertyName: step_type
oneOf:
- $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep'
@ -258,6 +264,8 @@ components:
additionalProperties: false
properties:
eval_candidate:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/ModelCandidate'
- $ref: '#/components/schemas/AgentCandidate'
@ -265,6 +273,8 @@ components:
type: integer
scoring_params:
additionalProperties:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -402,6 +412,8 @@ components:
additionalProperties: false
properties:
eval_candidate:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/ModelCandidate'
- $ref: '#/components/schemas/AgentCandidate'
@ -619,6 +631,8 @@ components:
title: streamed completion response.
type: object
ContentDelta:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/TextDelta'
- $ref: '#/components/schemas/ImageDelta'
@ -897,6 +911,8 @@ components:
type: string
type: array
task_config:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
- $ref: '#/components/schemas/AppEvalTaskConfig'
@ -1038,6 +1054,8 @@ components:
$ref: '#/components/schemas/InterleavedContentItem'
type: array
InterleavedContentItem:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/ImageContentItem'
- $ref: '#/components/schemas/TextContentItem'
@ -1244,6 +1262,8 @@ components:
additionalProperties: false
properties:
event:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/UnstructuredLogEvent'
- $ref: '#/components/schemas/MetricEvent'
@ -1325,6 +1345,8 @@ components:
- inserted_context
type: object
Message:
discriminator:
propertyName: role
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
@ -1495,6 +1517,8 @@ components:
- total_count
type: object
ParamType:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/StringType'
- $ref: '#/components/schemas/NumberType'
@ -1805,6 +1829,8 @@ components:
- max_chunks
type: object
RAGQueryGeneratorConfig:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
- $ref: '#/components/schemas/LLMRAGQueryGeneratorConfig'
@ -1922,6 +1948,8 @@ components:
description:
type: string
params:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -2002,6 +2030,8 @@ components:
- embedding_model
type: object
ResponseFormat:
discriminator:
propertyName: type
oneOf:
- additionalProperties: false
properties:
@ -2063,6 +2093,8 @@ components:
additionalProperties: false
properties:
task_config:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/BenchmarkEvalTaskConfig'
- $ref: '#/components/schemas/AppEvalTaskConfig'
@ -2130,6 +2162,8 @@ components:
default: 1.0
type: number
strategy:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/GreedySamplingStrategy'
- $ref: '#/components/schemas/TopPSamplingStrategy'
@ -2167,7 +2201,9 @@ components:
scoring_functions:
additionalProperties:
oneOf:
- oneOf:
- discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
@ -2208,7 +2244,9 @@ components:
scoring_functions:
additionalProperties:
oneOf:
- oneOf:
- discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
@ -2246,6 +2284,8 @@ components:
- type: object
type: object
params:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
@ -2503,6 +2543,8 @@ components:
- type: object
type: object
payload:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/SpanStartPayload'
- $ref: '#/components/schemas/SpanEndPayload'
@ -2528,6 +2570,8 @@ components:
additionalProperties: false
properties:
algorithm_config:
discriminator:
propertyName: type
oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QATFinetuningConfig'
@ -3115,6 +3159,8 @@ components:
type: string
steps:
items:
discriminator:
propertyName: step_type
oneOf:
- $ref: '#/components/schemas/InferenceStep'
- $ref: '#/components/schemas/ToolExecutionStep'

View file

@ -35,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v6"
KEY_VERSION = "v7"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -8,7 +8,7 @@ import random
import pytest
from llama_stack_client.types.tool_runtime import DocumentParam
from llama_stack_client.types import Document
@pytest.fixture(scope="function")
@ -38,22 +38,22 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry
@pytest.fixture(scope="session")
def sample_documents():
return [
DocumentParam(
Document(
document_id="test-doc-1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
DocumentParam(
Document(
document_id="test-doc-2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
DocumentParam(
Document(
document_id="test-doc-3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
DocumentParam(
Document(
document_id="test-doc-4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
@ -148,7 +148,7 @@ def test_vector_db_insert_from_url_and_query(
"llama3.rst",
]
documents = [
DocumentParam(
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",