Introduce extract_type_annotation method

Signed-off-by: thepetk <thepetk@gmail.com>
This commit is contained in:
thepetk 2025-11-07 21:00:59 +00:00
parent 63887f2a21
commit 58e81839a4
44 changed files with 87 additions and 71 deletions

View file

@ -5,11 +5,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import subprocess
import sys
from pathlib import Path
from typing import Any
from types import UnionType
from typing import Annotated, Any, Union, get_args, get_origin
from pydantic_core import PydanticUndefined
from rich.progress import Progress, SpinnerColumn, TextColumn
@ -54,6 +54,41 @@ class ChangedPathTracker:
return self._changed_paths
def extract_type_annotation(annotation: Any) -> str:
"""extract a type annotation into a clean string representation."""
if annotation is None:
return "Any"
if annotation is type(None):
return "None"
origin = get_origin(annotation)
args = get_args(annotation)
# recursive workaround for Annotated types to ignore FieldInfo part
if origin is Annotated and args:
return extract_type_annotation(args[0])
if origin in [Union, UnionType]:
non_none_args = [arg for arg in args if arg is not type(None)]
has_none = len(non_none_args) < len(args)
if len(non_none_args) == 1:
formatted = extract_type_annotation(non_none_args[0])
return f"{formatted} | None" if has_none else formatted
else:
formatted_args = [extract_type_annotation(arg) for arg in non_none_args]
result = " | ".join(formatted_args)
return f"{result} | None" if has_none else result
if origin is not None and args:
origin_name = getattr(origin, "__name__", str(origin))
formatted_args = [extract_type_annotation(arg) for arg in args]
return f"{origin_name}[{', '.join(formatted_args)}]"
return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
def get_config_class_info(config_class_path: str) -> dict[str, Any]:
"""Extract configuration information from a config class."""
try:
@ -84,27 +119,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
for field_name, field in config_class.model_fields.items():
if getattr(field, "exclude", False):
continue
field_type = str(field.annotation) if field.annotation else "Any"
# this string replace is ridiculous
field_type = (
field_type.replace("typing.", "")
.replace("Optional[", "")
.replace("]", "")
)
field_type = (
field_type.replace("Annotated[", "")
.replace("FieldInfo(", "")
.replace(")", "")
)
field_type = field_type.replace("llama_stack_api.inference.", "")
field_type = field_type.replace("llama_stack.providers.", "")
field_type = field_type.replace("llama_stack_api.datatypes.", "")
field_type = re.sub(r"Optional\[([^\]]+)\]", r"\1 | None", field_type)
field_type = re.sub(r"Annotated\[([^,]+),.*?\]", r"\1", field_type)
field_type = re.sub(r"FieldInfo\([^)]*\)", "", field_type)
field_type = re.sub(r"<class '([^']+)'>", r"\1", field_type)
field_type = extract_type_annotation(field.annotation)
default_value = field.default
if field.default_factory is not None: