mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Introduce extract_type_annotation method
Signed-off-by: thepetk <thepetk@gmail.com>
This commit is contained in:
parent
63887f2a21
commit
58e81839a4
44 changed files with 87 additions and 71 deletions
|
|
@ -5,11 +5,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from types import UnionType
|
||||
from typing import Annotated, Any, Union, get_args, get_origin
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
|
@ -54,6 +54,41 @@ class ChangedPathTracker:
|
|||
return self._changed_paths
|
||||
|
||||
|
||||
def extract_type_annotation(annotation: Any) -> str:
|
||||
"""extract a type annotation into a clean string representation."""
|
||||
if annotation is None:
|
||||
return "Any"
|
||||
|
||||
if annotation is type(None):
|
||||
return "None"
|
||||
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
# recursive workaround for Annotated types to ignore FieldInfo part
|
||||
if origin is Annotated and args:
|
||||
return extract_type_annotation(args[0])
|
||||
|
||||
if origin in [Union, UnionType]:
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
has_none = len(non_none_args) < len(args)
|
||||
|
||||
if len(non_none_args) == 1:
|
||||
formatted = extract_type_annotation(non_none_args[0])
|
||||
return f"{formatted} | None" if has_none else formatted
|
||||
else:
|
||||
formatted_args = [extract_type_annotation(arg) for arg in non_none_args]
|
||||
result = " | ".join(formatted_args)
|
||||
return f"{result} | None" if has_none else result
|
||||
|
||||
if origin is not None and args:
|
||||
origin_name = getattr(origin, "__name__", str(origin))
|
||||
formatted_args = [extract_type_annotation(arg) for arg in args]
|
||||
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||
|
||||
return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
|
||||
|
||||
|
||||
def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
||||
"""Extract configuration information from a config class."""
|
||||
try:
|
||||
|
|
@ -84,27 +119,8 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
for field_name, field in config_class.model_fields.items():
|
||||
if getattr(field, "exclude", False):
|
||||
continue
|
||||
field_type = str(field.annotation) if field.annotation else "Any"
|
||||
|
||||
# this string replace is ridiculous
|
||||
field_type = (
|
||||
field_type.replace("typing.", "")
|
||||
.replace("Optional[", "")
|
||||
.replace("]", "")
|
||||
)
|
||||
field_type = (
|
||||
field_type.replace("Annotated[", "")
|
||||
.replace("FieldInfo(", "")
|
||||
.replace(")", "")
|
||||
)
|
||||
field_type = field_type.replace("llama_stack_api.inference.", "")
|
||||
field_type = field_type.replace("llama_stack.providers.", "")
|
||||
field_type = field_type.replace("llama_stack_api.datatypes.", "")
|
||||
|
||||
field_type = re.sub(r"Optional\[([^\]]+)\]", r"\1 | None", field_type)
|
||||
field_type = re.sub(r"Annotated\[([^,]+),.*?\]", r"\1", field_type)
|
||||
field_type = re.sub(r"FieldInfo\([^)]*\)", "", field_type)
|
||||
field_type = re.sub(r"<class '([^']+)'>", r"\1", field_type)
|
||||
field_type = extract_type_annotation(field.annotation)
|
||||
|
||||
default_value = field.default
|
||||
if field.default_factory is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue