diff --git a/src/llama_stack/strong_typing/inspection.py b/src/llama_stack/strong_typing/inspection.py index d3ebc7585..319d12657 100644 --- a/src/llama_stack/strong_typing/inspection.py +++ b/src/llama_stack/strong_typing/inspection.py @@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]: return list_type # type: ignore[no-any-return] +def is_generic_sequence(typ: object) -> bool: + "True if the specified type is a generic Sequence, i.e. `Sequence[T]`." + import collections.abc + + typ = unwrap_annotated_type(typ) + return typing.get_origin(typ) is collections.abc.Sequence + + +def unwrap_generic_sequence(typ: object) -> type: + """ + Extracts the item type of a Sequence type. + + :param typ: The Sequence type `Sequence[T]`. + :returns: The item type `T`. + """ + + return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type] + + +def _unwrap_generic_sequence(typ: object) -> type: + "Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)." + + (sequence_type,) = typing.get_args(typ) # unpack single tuple element + return sequence_type # type: ignore[no-any-return] + + def is_generic_set(typ: object) -> TypeGuard[type[set]]: "True if the specified type is a generic set, i.e. `Set[T]`." diff --git a/src/llama_stack/strong_typing/name.py b/src/llama_stack/strong_typing/name.py index 00cdc2ae2..60501ac43 100644 --- a/src/llama_stack/strong_typing/name.py +++ b/src/llama_stack/strong_typing/name.py @@ -18,10 +18,12 @@ from .inspection import ( TypeLike, is_generic_dict, is_generic_list, + is_generic_sequence, is_type_optional, is_type_union, unwrap_generic_dict, unwrap_generic_list, + unwrap_generic_sequence, unwrap_optional_type, unwrap_union_types, ) @@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str: if metadata is not None: # type is Annotated[T, ...] arg = typing.get_args(data_type)[0] - return python_type_to_name(arg) + return python_type_to_name(arg, force=force) if force: # generic types if is_type_optional(data_type, strict=True): - inner_name = python_type_to_name(unwrap_optional_type(data_type)) + inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True) return f"Optional__{inner_name}" elif is_generic_list(data_type): - item_name = python_type_to_name(unwrap_generic_list(data_type)) + item_name = python_type_to_name(unwrap_generic_list(data_type), force=True) + return f"List__{item_name}" + elif is_generic_sequence(data_type): + # Treat Sequence the same as List for schema generation purposes + item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True) return f"List__{item_name}" elif is_generic_dict(data_type): key_type, value_type = unwrap_generic_dict(data_type) - key_name = python_type_to_name(key_type) - value_name = python_type_to_name(value_type) + key_name = python_type_to_name(key_type, force=True) + value_name = python_type_to_name(value_type, force=True) return f"Dict__{key_name}__{value_name}" elif is_type_union(data_type): member_types = unwrap_union_types(data_type) - member_names = "__".join(python_type_to_name(member_type) for member_type in member_types) + member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types) return f"Union__{member_names}" # named system or user-defined type diff --git a/src/llama_stack/strong_typing/schema.py b/src/llama_stack/strong_typing/schema.py index 15a3bbbfc..916690e41 100644 --- a/src/llama_stack/strong_typing/schema.py +++ b/src/llama_stack/strong_typing/schema.py @@ -111,7 +111,7 @@ def get_class_property_docstrings( def docstring_to_schema(data_type: type) -> Schema: short_description, long_description = get_class_docstrings(data_type) schema: Schema = { - "title": python_type_to_name(data_type), + "title": python_type_to_name(data_type, force=True), } description = "\n".join(filter(None, [short_description, long_description])) @@ -417,6 +417,10 @@ class JsonSchemaGenerator: if origin_type is list: (list_type,) = typing.get_args(typ) # unpack single tuple element return {"type": "array", "items": self.type_to_schema(list_type)} + elif origin_type is collections.abc.Sequence: + # Treat Sequence the same as list for JSON schema (both are arrays) + (sequence_type,) = typing.get_args(typ) # unpack single tuple element + return {"type": "array", "items": self.type_to_schema(sequence_type)} elif origin_type is dict: key_type, value_type = typing.get_args(typ) if not (key_type is str or key_type is int or is_type_enum(key_type)):