forked from phoenix-oss/llama-stack-mirror
upgrade pydantic to latest
This commit is contained in:
parent
2cd8b2ff5b
commit
37da47ef8e
4 changed files with 22 additions and 23 deletions
|
@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||
|
@ -42,6 +42,8 @@ class StepType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
@ -157,6 +159,8 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||
AgenticSystemTurnResponseEventType.step_progress.value
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ from enum import Enum
|
|||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
@ -43,19 +43,11 @@ def get_non_none_type(field_type):
|
|||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
|
||||
def manually_validate_field(model: Type[BaseModel], field: ModelField, value: Any):
|
||||
validators = field.class_validators.values()
|
||||
|
||||
for validator in validators:
|
||||
if validator.pre:
|
||||
value = validator.func(model, value)
|
||||
|
||||
# Apply type coercion
|
||||
value = field.type_(value)
|
||||
|
||||
for validator in validators:
|
||||
if not validator.pre:
|
||||
value = validator.func(model, value)
|
||||
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
|
||||
validators = model.__pydantic_decorators__.field_validators
|
||||
for _name, validator in validators.items():
|
||||
if field_name in validator.info.fields:
|
||||
validator.func(value)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -89,9 +81,11 @@ def prompt_for_config(
|
|||
default_value = existing_value
|
||||
else:
|
||||
default_value = (
|
||||
field.default if not isinstance(field.default, type(Ellipsis)) else None
|
||||
field.default
|
||||
if not isinstance(field.default, PydanticUndefinedType)
|
||||
else None
|
||||
)
|
||||
is_required = field.required
|
||||
is_required = field.is_required
|
||||
|
||||
# Skip fields with Literal type
|
||||
if get_origin(field_type) is Literal:
|
||||
|
@ -247,7 +241,9 @@ def prompt_for_config(
|
|||
|
||||
try:
|
||||
# Validate the field using our manual validation function
|
||||
validated_value = manually_validate_field(config_type, field, value)
|
||||
validated_value = manually_validate_field(
|
||||
config_type, field_name, value
|
||||
)
|
||||
config_data[field_name] = validated_value
|
||||
break
|
||||
except ValueError as e:
|
||||
|
|
|
@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
|
|||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -27,7 +27,7 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
|
||||
@validator("model")
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
|
|
|
@ -2,6 +2,5 @@ fire
|
|||
httpx
|
||||
huggingface-hub
|
||||
llama-models
|
||||
pydantic==1.10.13
|
||||
pydantic_core==2.18.2
|
||||
pydantic
|
||||
requests
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue