mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 13:00:52 +00:00
fix(responses): type aliasing not supported for pydantic code generation and discrimintated unions
This commit is contained in:
parent
8fb17ba18e
commit
80b82c070c
5 changed files with 287 additions and 73 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
@ -14,21 +14,20 @@ from llama_stack.apis.tools.openai_tool_choice import (
|
|||
ToolChoiceCustom,
|
||||
ToolChoiceFunction,
|
||||
ToolChoiceMcp,
|
||||
ToolChoiceOptions,
|
||||
ToolChoiceTypes,
|
||||
)
|
||||
from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
type OpenAIResponsesToolChoice = Annotated[
|
||||
Union[
|
||||
ToolChoiceTypes,
|
||||
ToolChoiceAllowed,
|
||||
ToolChoiceFunction,
|
||||
ToolChoiceMcp,
|
||||
ToolChoiceCustom
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
OpenAIResponsesToolChoice = (
|
||||
ToolChoiceOptions
|
||||
| ToolChoiceTypes # Multiple type values - can't use a discriminator here
|
||||
| Annotated[
|
||||
ToolChoiceAllowed | ToolChoiceFunction | ToolChoiceMcp | ToolChoiceCustom,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
)
|
||||
register_schema(OpenAIResponsesToolChoice, name="OpenAIResponsesToolChoice")
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
type ToolChoiceOptions = Literal["none", "auto", "required"]
|
||||
ToolChoiceOptions = Literal["none", "auto", "required"]
|
||||
register_schema(ToolChoiceOptions, name="ToolChoiceOptions")
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ class ToolChoiceTypes(BaseModel):
|
|||
"image_generation",
|
||||
"code_interpreter",
|
||||
]
|
||||
"""The type of hosted tool the model should to use.
|
||||
"""The type of hosted tool the model should use.
|
||||
|
||||
Allowed values are:
|
||||
|
||||
|
@ -61,7 +61,7 @@ class ToolChoiceAllowed(BaseModel):
|
|||
```
|
||||
"""
|
||||
|
||||
type: Literal["allowed_tools"]
|
||||
type: Literal["allowed_tools"] = "allowed_tools"
|
||||
"""Allowed tool configuration type. Always `allowed_tools`."""
|
||||
|
||||
|
||||
|
@ -70,7 +70,7 @@ class ToolChoiceFunction(BaseModel):
|
|||
name: str
|
||||
"""The name of the function to call."""
|
||||
|
||||
type: Literal["function"]
|
||||
type: Literal["function"] = "function"
|
||||
"""For function calling, the type is always `function`."""
|
||||
|
||||
|
||||
|
@ -79,7 +79,7 @@ class ToolChoiceMcp(BaseModel):
|
|||
server_label: str
|
||||
"""The label of the MCP server to use."""
|
||||
|
||||
type: Literal["mcp"]
|
||||
type: Literal["mcp"] = "mcp"
|
||||
"""For MCP tools, the type is always `mcp`."""
|
||||
|
||||
name: str | None = None
|
||||
|
@ -91,5 +91,5 @@ class ToolChoiceCustom(BaseModel):
|
|||
name: str
|
||||
"""The name of the custom tool to call."""
|
||||
|
||||
type: Literal["custom"]
|
||||
type: Literal["custom"] = "custom"
|
||||
"""For custom tool calling, the type is always `custom`."""
|
||||
|
|
|
@ -93,14 +93,7 @@ def get_class_property_docstrings(
|
|||
"""
|
||||
|
||||
result = {}
|
||||
# Check if the type has __mro__ (method resolution order)
|
||||
if hasattr(data_type, "__mro__"):
|
||||
bases = inspect.getmro(data_type)
|
||||
else:
|
||||
# For TypeAliasType or other types without __mro__, just use the type itself
|
||||
bases = [data_type] if hasattr(data_type, "__doc__") else []
|
||||
|
||||
for base in bases:
|
||||
for base in inspect.getmro(data_type):
|
||||
docstr = docstring.parse_type(base)
|
||||
for param in docstr.params.values():
|
||||
if param.name in result:
|
||||
|
@ -512,24 +505,13 @@ class JsonSchemaGenerator:
|
|||
(concrete_type,) = typing.get_args(typ)
|
||||
return self.type_to_schema(concrete_type)
|
||||
|
||||
# Check if this is a TypeAliasType (Python 3.12+) which doesn't have __mro__
|
||||
if hasattr(typ, "__mro__"):
|
||||
# dictionary of class attributes
|
||||
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
|
||||
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
|
||||
else:
|
||||
# TypeAliasType or other types without __mro__
|
||||
members = {}
|
||||
property_docstrings = {}
|
||||
# dictionary of class attributes
|
||||
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
|
||||
|
||||
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
|
||||
properties: Dict[str, Schema] = {}
|
||||
required: List[str] = []
|
||||
# Only process properties if the type supports class properties
|
||||
if hasattr(typ, "__mro__"):
|
||||
class_properties = get_class_properties(typ)
|
||||
else:
|
||||
class_properties = []
|
||||
|
||||
for property_name, property_type in class_properties:
|
||||
for property_name, property_type in get_class_properties(typ):
|
||||
# rename property if an alias name is specified
|
||||
alias = get_annotation(property_type, Alias)
|
||||
if alias:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue