forked from phoenix-oss/llama-stack-mirror
upgrade pydantic to latest
This commit is contained in:
parent
2cd8b2ff5b
commit
37da47ef8e
4 changed files with 22 additions and 23 deletions
|
@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||||
|
@ -42,6 +42,8 @@ class StepType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class InferenceStep(StepCommon):
|
class InferenceStep(StepCommon):
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||||
model_response: CompletionMessage
|
model_response: CompletionMessage
|
||||||
|
|
||||||
|
@ -157,6 +159,8 @@ class AgenticSystemTurnResponseStepCompletePayload(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
class AgenticSystemTurnResponseStepProgressPayload(BaseModel):
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
event_type: Literal[AgenticSystemTurnResponseEventType.step_progress.value] = (
|
||||||
AgenticSystemTurnResponseEventType.step_progress.value
|
AgenticSystemTurnResponseEventType.step_progress.value
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from enum import Enum
|
||||||
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import ModelField
|
from pydantic_core import PydanticUndefinedType
|
||||||
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
@ -43,19 +43,11 @@ def get_non_none_type(field_type):
|
||||||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||||
|
|
||||||
|
|
||||||
def manually_validate_field(model: Type[BaseModel], field: ModelField, value: Any):
|
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
|
||||||
validators = field.class_validators.values()
|
validators = model.__pydantic_decorators__.field_validators
|
||||||
|
for _name, validator in validators.items():
|
||||||
for validator in validators:
|
if field_name in validator.info.fields:
|
||||||
if validator.pre:
|
validator.func(value)
|
||||||
value = validator.func(model, value)
|
|
||||||
|
|
||||||
# Apply type coercion
|
|
||||||
value = field.type_(value)
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if not validator.pre:
|
|
||||||
value = validator.func(model, value)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -89,9 +81,11 @@ def prompt_for_config(
|
||||||
default_value = existing_value
|
default_value = existing_value
|
||||||
else:
|
else:
|
||||||
default_value = (
|
default_value = (
|
||||||
field.default if not isinstance(field.default, type(Ellipsis)) else None
|
field.default
|
||||||
|
if not isinstance(field.default, PydanticUndefinedType)
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
is_required = field.required
|
is_required = field.is_required
|
||||||
|
|
||||||
# Skip fields with Literal type
|
# Skip fields with Literal type
|
||||||
if get_origin(field_type) is Literal:
|
if get_origin(field_type) is Literal:
|
||||||
|
@ -247,7 +241,9 @@ def prompt_for_config(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate the field using our manual validation function
|
# Validate the field using our manual validation function
|
||||||
validated_value = manually_validate_field(config_type, field, value)
|
validated_value = manually_validate_field(
|
||||||
|
config_type, field_name, value
|
||||||
|
)
|
||||||
config_data[field_name] = validated_value
|
config_data[field_name] = validated_value
|
||||||
break
|
break
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
@ -11,9 +11,9 @@ from llama_models.datatypes import ModelFamily
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_toolchain.inference.api import QuantizationConfig
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from llama_toolchain.inference.api import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -27,7 +27,7 @@ class MetaReferenceImplConfig(BaseModel):
|
||||||
max_seq_len: int
|
max_seq_len: int
|
||||||
max_batch_size: int = 1
|
max_batch_size: int = 1
|
||||||
|
|
||||||
@validator("model")
|
@field_validator("model")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model(cls, model: str) -> str:
|
def validate_model(cls, model: str) -> str:
|
||||||
permitted_models = [
|
permitted_models = [
|
||||||
|
|
|
@ -2,6 +2,5 @@ fire
|
||||||
httpx
|
httpx
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
llama-models
|
llama-models
|
||||||
pydantic==1.10.13
|
pydantic
|
||||||
pydantic_core==2.18.2
|
|
||||||
requests
|
requests
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue