mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):
|
|||
|
||||
|
||||
def is_basemodel_without_fields(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||
)
|
||||
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
|
||||
|
||||
|
||||
def can_recurse(typ):
|
||||
return (
|
||||
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
)
|
||||
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
|
||||
|
||||
|
||||
def get_literal_values(field):
|
||||
|
@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
|
|||
if isinstance(typ, FieldInfo):
|
||||
return typ.discriminator
|
||||
else:
|
||||
if not (get_origin(typ) is Annotated):
|
||||
if get_origin(typ) is not Annotated:
|
||||
return False
|
||||
args = get_args(typ)
|
||||
return len(args) >= 2 and args[1].discriminator
|
||||
|
@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
|
|||
chosen_type = type_map[discriminator_value]
|
||||
log.info(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
||||
if existing_value and (
|
||||
getattr(existing_value, discriminator) != discriminator_value
|
||||
):
|
||||
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
|
||||
existing_value = None
|
||||
|
||||
sub_config = prompt_for_config(chosen_type, existing_value)
|
||||
|
@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
|
|||
#
|
||||
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
|
||||
# unit tests for coverage.
|
||||
def prompt_for_config(
|
||||
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
||||
) -> BaseModel:
|
||||
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
|
||||
"""
|
||||
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
||||
|
||||
|
@ -150,17 +142,11 @@ def prompt_for_config(
|
|||
|
||||
for field_name, field in config_type.__fields__.items():
|
||||
field_type = field.annotation
|
||||
existing_value = (
|
||||
getattr(existing_config, field_name) if existing_config else None
|
||||
)
|
||||
existing_value = getattr(existing_config, field_name) if existing_config else None
|
||||
if existing_value:
|
||||
default_value = existing_value
|
||||
else:
|
||||
default_value = (
|
||||
field.default
|
||||
if not isinstance(field.default, PydanticUndefinedType)
|
||||
else None
|
||||
)
|
||||
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
|
||||
is_required = field.is_required
|
||||
|
||||
# Skip fields with Literal type
|
||||
|
@ -183,15 +169,11 @@ def prompt_for_config(
|
|||
config_data[field_name] = validated_value
|
||||
break
|
||||
except KeyError:
|
||||
log.error(
|
||||
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
|
||||
)
|
||||
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
|
||||
continue
|
||||
|
||||
if is_discriminated_union(field):
|
||||
config_data[field_name] = prompt_for_discriminated_union(
|
||||
field_name, field, existing_value
|
||||
)
|
||||
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
|
||||
continue
|
||||
|
||||
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
|
||||
|
@ -202,9 +184,7 @@ def prompt_for_config(
|
|||
nested_type = get_non_none_type(field_type)
|
||||
log.info(f"Entering sub-configuration for {field_name}:")
|
||||
config_data[field_name] = prompt_for_config(nested_type, existing_value)
|
||||
elif is_optional(field_type) and is_discriminated_union(
|
||||
get_non_none_type(field_type)
|
||||
):
|
||||
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
|
||||
prompt = f"Do you want to configure {field_name}? (y/n): "
|
||||
if input(prompt).lower() == "n":
|
||||
config_data[field_name] = None
|
||||
|
@ -260,16 +240,12 @@ def prompt_for_config(
|
|||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded list"
|
||||
)
|
||||
raise ValueError("Input must be a JSON-encoded list")
|
||||
element_type = get_args(field_type)[0]
|
||||
value = [element_type(item) for item in value]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
log.error(
|
||||
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
|
||||
)
|
||||
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
|
||||
continue
|
||||
except ValueError as e:
|
||||
log.error(f"{str(e)}")
|
||||
|
@ -279,20 +255,14 @@ def prompt_for_config(
|
|||
try:
|
||||
value = json.loads(user_input)
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(
|
||||
"Input must be a JSON-encoded dictionary"
|
||||
)
|
||||
raise ValueError("Input must be a JSON-encoded dictionary")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
log.error(
|
||||
"Invalid JSON. Please enter a valid JSON-encoded dict."
|
||||
)
|
||||
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
|
||||
continue
|
||||
|
||||
# Convert the input to the correct type
|
||||
elif inspect.isclass(field_type) and issubclass(
|
||||
field_type, BaseModel
|
||||
):
|
||||
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
# For nested BaseModels, we assume a dictionary-like string input
|
||||
import ast
|
||||
|
||||
|
@ -301,16 +271,12 @@ def prompt_for_config(
|
|||
value = field_type(user_input)
|
||||
|
||||
except ValueError:
|
||||
log.error(
|
||||
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
|
||||
)
|
||||
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Validate the field using our manual validation function
|
||||
validated_value = manually_validate_field(
|
||||
config_type, field_name, value
|
||||
)
|
||||
validated_value = manually_validate_field(config_type, field_name, value)
|
||||
config_data[field_name] = validated_value
|
||||
break
|
||||
except ValueError as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue