Register Message and ResponseFormat

This commit is contained in:
Ashwin Bharambe 2024-12-18 10:32:25 -08:00
parent ceadaf1840
commit 12cbed1617
3 changed files with 195 additions and 335 deletions

View file

@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
@ -100,15 +100,18 @@ class CompletionMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
Message = Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
Message = register_schema(
Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
Field(discriminator="role"),
],
Field(discriminator="role"),
]
name="Message",
)
@json_schema_type
@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
ResponseFormat = register_schema(
Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
],
name="ResponseFormat",
)
@json_schema_type