diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index 071d40cb6..35e2aaa69 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -8,9 +8,10 @@ import inspect import json from enum import Enum -from typing import get_args, get_origin, List, Literal, Optional, Union +from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union from pydantic import BaseModel +from pydantic.fields import ModelField from typing_extensions import Annotated @@ -42,6 +43,23 @@ def get_non_none_type(field_type): return next(arg for arg in get_args(field_type) if arg is not type(None)) +def manually_validate_field(model: Type[BaseModel], field: ModelField, value: Any): + validators = field.class_validators.values() + + for validator in validators: + if validator.pre: + value = validator.func(model, value) + + # Apply type coercion + value = field.type_(value) + + for validator in validators: + if not validator.pre: + value = validator.func(model, value) + + return value + + # This is somewhat elaborate, but does not purport to be comprehensive in any way. # We should add handling for the most common cases to tide us over. # @@ -85,7 +103,9 @@ def prompt_for_config( # this branch does not handle existing and default values yet user_input = input(prompt + " ") try: - config_data[field_name] = field_type[user_input] + value = field_type[user_input] + validated_value = manually_validate_field(config_type, field, value) + config_data[field_name] = validated_value break except KeyError: print( @@ -178,51 +198,59 @@ def prompt_for_config( else: print("This field is required. Please provide a value.") continue + else: + try: + # Handle Optional types + if is_optional(field_type): + if user_input.lower() == "none": + value = None + else: + field_type = get_non_none_type(field_type) + value = user_input + + # Handle List of primitives + elif is_list_of_primitives(field_type): + try: + value = json.loads(user_input) + if not isinstance(value, list): + raise ValueError( + "Input must be a JSON-encoded list" + ) + element_type = get_args(field_type)[0] + value = [element_type(item) for item in value] + + except json.JSONDecodeError: + print( + "Invalid JSON. Please enter a valid JSON-encoded list." + ) + continue + except ValueError as e: + print(f"{str(e)}") + continue + + # Convert the input to the correct type + elif inspect.isclass(field_type) and issubclass( + field_type, BaseModel + ): + # For nested BaseModels, we assume a dictionary-like string input + import ast + + value = field_type(**ast.literal_eval(user_input)) + else: + value = field_type(user_input) + + except ValueError: + print( + f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}" + ) + continue try: - # Handle Optional types - if is_optional(field_type): - if user_input.lower() == "none": - config_data[field_name] = None - break - field_type = get_non_none_type(field_type) - - # Handle List of primitives - if is_list_of_primitives(field_type): - try: - value = json.loads(user_input) - if not isinstance(value, list): - raise ValueError("Input must be a JSON-encoded list") - element_type = get_args(field_type)[0] - config_data[field_name] = [ - element_type(item) for item in value - ] - break - except json.JSONDecodeError: - print( - "Invalid JSON. Please enter a valid JSON-encoded list." - ) - continue - except ValueError as e: - print(f"{str(e)}") - continue - - # Convert the input to the correct type - if inspect.isclass(field_type) and issubclass( - field_type, BaseModel - ): - # For nested BaseModels, we assume a dictionary-like string input - import ast - - config_data[field_name] = field_type( - **ast.literal_eval(user_input) - ) - else: - config_data[field_name] = field_type(user_input) + # Validate the field using our manual validation function + validated_value = manually_validate_field(config_type, field, value) + config_data[field_name] = validated_value break - except ValueError: - print( - f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}" - ) + except ValueError as e: + print(f"Validation error: {str(e)}") return config_type(**config_data) diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py index b921f7347..6757b1cfa 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -6,9 +6,12 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type +from llama_models.datatypes import ModelFamily -from pydantic import BaseModel +from llama_models.schema_utils import json_schema_type +from llama_models.sku_list import all_registered_models + +from pydantic import BaseModel, validator from llama_toolchain.inference.api import QuantizationConfig @@ -20,3 +23,18 @@ class MetaReferenceImplConfig(BaseModel): torch_seed: Optional[int] = None max_seq_len: int max_batch_size: int = 1 + + @validator("model") + @classmethod + def validate_model(cls, model: str) -> str: + permitted_models = [ + m.descriptor() + for m in all_registered_models() + if m.model_family == ModelFamily.llama3_1 + ] + if model not in permitted_models: + model_list = "\n\t".join(permitted_models) + raise ValueError( + f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]" + ) + return model diff --git a/llama_toolchain/safety/meta_reference/config.py b/llama_toolchain/safety/meta_reference/config.py index 8022e4f58..4d68d2e48 100644 --- a/llama_toolchain/safety/meta_reference/config.py +++ b/llama_toolchain/safety/meta_reference/config.py @@ -6,7 +6,9 @@ from typing import List, Optional -from pydantic import BaseModel +from llama_models.sku_list import CoreModelId, safety_models + +from pydantic import BaseModel, validator class LlamaGuardShieldConfig(BaseModel): @@ -15,10 +17,38 @@ class LlamaGuardShieldConfig(BaseModel): disable_input_check: bool = False disable_output_check: bool = False + @validator("model") + @classmethod + def validate_model(cls, model: str) -> str: + permitted_models = [ + m.descriptor() + for m in safety_models() + if m.core_model_id == CoreModelId.llama_guard_3_8b + ] + if model not in permitted_models: + raise ValueError( + f"Invalid model: {model}. Must be one of {permitted_models}" + ) + return model + class PromptGuardShieldConfig(BaseModel): model: str = "Prompt-Guard-86M" + @validator("model") + @classmethod + def validate_model(cls, model: str) -> str: + permitted_models = [ + m.descriptor() + for m in safety_models() + if m.core_model_id == CoreModelId.prompt_guard_86m + ] + if model not in permitted_models: + raise ValueError( + f"Invalid model: {model}. Must be one of {permitted_models}" + ) + return model + class SafetyConfig(BaseModel): llama_guard_shield: Optional[LlamaGuardShieldConfig] = None