upgrade pydantic to latest

This commit is contained in:
Hardik Shah 2024-08-12 15:14:21 -07:00
parent 2cd8b2ff5b
commit 37da47ef8e
4 changed files with 22 additions and 23 deletions

View file

@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_toolchain.common.deployment_types import * # noqa: F403 from llama_toolchain.common.deployment_types import * # noqa: F403
@ -42,6 +42,8 @@ class StepType(Enum):
@json_schema_type @json_schema_type
class InferenceStep(StepCommon): class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value step_type: Literal[StepType.inference.value] = StepType.inference.value
model_response: CompletionMessage model_response: CompletionMessage
@ -157,6 +159,8 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
@json_schema_type @json_schema_type
class AgenticSystemTurnResponseStepProgressPayload(BaseModel): class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = ( event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
AgenticSystemTurnResponseEventType.step_progress.value AgenticSystemTurnResponseEventType.step_progress.value
) )

View file

@ -11,7 +11,7 @@ from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import ModelField from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated from typing_extensions import Annotated
@ -43,19 +43,11 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None)) return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: Type[BaseModel], field: ModelField, value: Any): def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
validators = field.class_validators.values() validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
for validator in validators: if field_name in validator.info.fields:
if validator.pre: validator.func(value)
value = validator.func(model, value)
# Apply type coercion
value = field.type_(value)
for validator in validators:
if not validator.pre:
value = validator.func(model, value)
return value return value
@ -89,9 +81,11 @@ def prompt_for_config(
default_value = existing_value default_value = existing_value
else: else:
default_value = ( default_value = (
field.default if not isinstance(field.default, type(Ellipsis)) else None field.default
if not isinstance(field.default, PydanticUndefinedType)
else None
) )
is_required = field.required is_required = field.is_required
# Skip fields with Literal type # Skip fields with Literal type
if get_origin(field_type) is Literal: if get_origin(field_type) is Literal:
@ -247,7 +241,9 @@ def prompt_for_config(
try: try:
# Validate the field using our manual validation function # Validate the field using our manual validation function
validated_value = manually_validate_field(config_type, field, value) validated_value = manually_validate_field(
config_type, field_name, value
)
config_data[field_name] = validated_value config_data[field_name] = validated_value
break break
except ValueError as e: except ValueError as e:

View file

@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_toolchain.inference.api import QuantizationConfig from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, validator from llama_toolchain.inference.api import QuantizationConfig
@json_schema_type @json_schema_type
@ -27,7 +27,7 @@ class MetaReferenceImplConfig(BaseModel):
max_seq_len: int max_seq_len: int
max_batch_size: int = 1 max_batch_size: int = 1
@validator("model") @field_validator("model")
@classmethod @classmethod
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
permitted_models = [ permitted_models = [

View file

@ -2,6 +2,5 @@ fire
httpx httpx
huggingface-hub huggingface-hub
llama-models llama-models
pydantic==1.10.13 pydantic
pydantic_core==2.18.2
requests requests