From 37da47ef8ee9234f370b3105d006ef20fb3cacab Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Mon, 12 Aug 2024 15:14:21 -0700 Subject: [PATCH] upgrade pydantic to latest --- .../agentic_system/api/datatypes.py | 6 +++- llama_toolchain/common/prompt_for_config.py | 30 ++++++++----------- .../inference/meta_reference/config.py | 6 ++-- requirements.txt | 3 +- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index 0c8c1f4c8..1dda64834 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated from llama_toolchain.common.deployment_types import * # noqa: F403 @@ -42,6 +42,8 @@ class StepType(Enum): @json_schema_type class InferenceStep(StepCommon): + model_config = ConfigDict(protected_namespaces=()) + step_type: Literal[StepType.inference.value] = StepType.inference.value model_response: CompletionMessage @@ -157,6 +159,8 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel): @json_schema_type class AgenticSystemTurnResponseStepProgressPayload(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = ( AgenticSystemTurnResponseEventType.step_progress.value ) diff --git a/llama_toolchain/common/prompt_for_config.py b/llama_toolchain/common/prompt_for_config.py index d29180520..6c53477d8 100644 --- a/llama_toolchain/common/prompt_for_config.py +++ b/llama_toolchain/common/prompt_for_config.py @@ -11,7 +11,7 @@ from enum import Enum from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union from pydantic import BaseModel 
-from pydantic.fields import ModelField +from pydantic_core import PydanticUndefinedType from typing_extensions import Annotated @@ -43,19 +43,11 @@ def get_non_none_type(field_type): return next(arg for arg in get_args(field_type) if arg is not type(None)) -def manually_validate_field(model: Type[BaseModel], field: ModelField, value: Any): - validators = field.class_validators.values() - - for validator in validators: - if validator.pre: - value = validator.func(model, value) - - # Apply type coercion - value = field.type_(value) - - for validator in validators: - if not validator.pre: - value = validator.func(model, value) +def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any): + validators = model.__pydantic_decorators__.field_validators + for _name, validator in validators.items(): + if field_name in validator.info.fields: + value = validator.func(value) return value @@ -89,9 +81,11 @@ def prompt_for_config( default_value = existing_value else: default_value = ( - field.default if not isinstance(field.default, type(Ellipsis)) else None + field.default + if not isinstance(field.default, PydanticUndefinedType) + else None ) - is_required = field.required + is_required = field.is_required() # Skip fields with Literal type if get_origin(field_type) is Literal: @@ -247,7 +241,9 @@ def prompt_for_config( try: # Validate the field using our manual validation function - validated_value = manually_validate_field(config_type, field, value) + validated_value = manually_validate_field( + config_type, field_name, value + ) config_data[field_name] = validated_value break except ValueError as e: diff --git a/llama_toolchain/inference/meta_reference/config.py b/llama_toolchain/inference/meta_reference/config.py index d9aef32e6..d2e601680 100644 --- a/llama_toolchain/inference/meta_reference/config.py +++ b/llama_toolchain/inference/meta_reference/config.py @@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily from llama_models.schema_utils import
json_schema_type from llama_models.sku_list import all_registered_models -from llama_toolchain.inference.api import QuantizationConfig +from pydantic import BaseModel, Field, field_validator -from pydantic import BaseModel, Field, validator +from llama_toolchain.inference.api import QuantizationConfig @json_schema_type @@ -27,7 +27,7 @@ class MetaReferenceImplConfig(BaseModel): max_seq_len: int max_batch_size: int = 1 - @validator("model") + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: permitted_models = [ diff --git a/requirements.txt b/requirements.txt index 8d3379e7b..f13e0819b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,5 @@ fire httpx huggingface-hub llama-models -pydantic==1.10.13 -pydantic_core==2.18.2 +pydantic requests