diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9bdb10d95..9b8b9a8df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,7 @@ repos: rev: v0.9.4 hooks: - id: ruff + exclude: ^llama_stack/strong_typing/.*$ - id: ruff-format - repo: https://github.com/adamchainz/blacken-docs @@ -43,7 +44,13 @@ repos: rev: 0.5.26 hooks: - id: uv-export - args: ["--frozen", "--no-hashes", "--no-emit-project"] + args: [ + "--frozen", + "--no-hashes", + "--no-emit-project", + "--output-file=requirements.txt" + ] + files: ^pyproject\.toml$ - id: uv-sync # - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index 48109e5d8..dcbee7d2f 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -16,18 +16,6 @@ from pathlib import Path import fire import ruamel.yaml as yaml -from llama_models import schema_utils - -# We do some monkey-patching to ensure our definitions only use the minimal -# (json_schema_type, webmethod) definitions from the llama_models package. For -# generation though, we need the full definitions and implementations from the -# (json-strong-typing) package. - -from .strong_typing.schema import json_schema_type, register_schema - -schema_utils.json_schema_type = json_schema_type -schema_utils.register_schema = register_schema - from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402 from llama_stack.distribution.stack import LlamaStack # noqa: E402 diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index 0f3b99784..60cd7a242 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -10,9 +10,9 @@ import typing from dataclasses import make_dataclass from typing import Any, Dict, Set, Union -from ..strong_typing.core import JsonType -from ..strong_typing.docstring import Docstring, parse_type -from ..strong_typing.inspection import ( +from llama_stack.strong_typing.core import JsonType +from llama_stack.strong_typing.docstring import Docstring, parse_type +from llama_stack.strong_typing.inspection import ( is_generic_list, is_type_optional, is_type_union, @@ -20,15 +20,15 @@ from ..strong_typing.inspection import ( unwrap_optional_type, unwrap_union_types, ) -from ..strong_typing.name import python_type_to_name -from ..strong_typing.schema import ( +from llama_stack.strong_typing.name import python_type_to_name +from llama_stack.strong_typing.schema import ( get_schema_identifier, JsonSchemaGenerator, register_schema, Schema, SchemaOptions, ) -from ..strong_typing.serialization import json_dump_string, object_to_json +from llama_stack.strong_typing.serialization import json_dump_string, object_to_json from .operations import ( EndpointOperation, diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index bf4d35c87..88a403182 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION from termcolor import colored -from ..strong_typing.inspection import get_signature +from llama_stack.strong_typing.inspection import get_signature def split_prefix( diff --git a/docs/openapi_generator/pyopenapi/specification.py b/docs/openapi_generator/pyopenapi/specification.py index f96de58b6..9e5363b4a 100644 --- 
a/docs/openapi_generator/pyopenapi/specification.py +++ b/docs/openapi_generator/pyopenapi/specification.py @@ -9,7 +9,7 @@ import enum from dataclasses import dataclass from typing import Any, ClassVar, Dict, List, Optional, Union -from ..strong_typing.schema import JsonType, Schema, StrictJsonType +from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType URL = str diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index 54f10d473..f134aab4b 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -9,7 +9,7 @@ import typing from pathlib import Path from typing import TextIO -from ..strong_typing.schema import object_to_json, StrictJsonType +from llama_stack.strong_typing.schema import object_to_json, StrictJsonType from .generator import Generator from .options import Options diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 106d34584..ccd15c3d6 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -19,7 +19,6 @@ from typing import ( runtime_checkable, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent @@ -38,6 +37,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod class Attachment(BaseModel): diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py deleted file mode 100644 index 835ce4cee..000000000 --- a/llama_stack/apis/agents/event_logger.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Optional - -from llama_models.llama3.api.datatypes import ToolPromptFormat -from llama_models.llama3.api.tool_utils import ToolUtils -from termcolor import cprint - -from llama_stack.apis.agents import AgentTurnResponseEventType, StepType -from llama_stack.apis.common.content_types import ToolCallParseStatus -from llama_stack.apis.inference import ToolResponseMessage -from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, -) - - -class LogEvent: - def __init__( - self, - role: Optional[str] = None, - content: str = "", - end: str = "\n", - color="white", - ): - self.role = role - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def __str__(self): - if self.role is not None: - return f"{self.role}> {self.content}" - else: - return f"{self.content}" - - def print(self, flush=True): - cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) - - -EventType = AgentTurnResponseEventType - - -class EventLogger: - async def log( - self, - event_generator, - stream=True, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, - ): - previous_event_type = None - previous_step_type = None - - async for chunk in event_generator: - if not hasattr(chunk, "event"): - # Need to check for custom tool first - # since it does not produce event but instead - # a Message - if isinstance(chunk, ToolResponseMessage): - yield ( - chunk, - LogEvent(role="CustomTool", content=chunk.content, color="grey"), - ) - continue - - event = chunk.event - event_type = event.payload.event_type - if event_type in { - EventType.turn_start.value, - EventType.turn_complete.value, - }: - # Currently not logging any turn realted info - yield event, None - continue - - step_type = event.payload.step_type - # handle safety - if step_type == StepType.shield_call and event_type == EventType.step_complete.value: - violation = event.payload.step_details.violation - if not violation: - yield ( - event, - LogEvent(role=step_type, content="No Violation", color="magenta"), - ) - else: - yield ( - event, - LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", - ), - ) - - # handle inference - if step_type == StepType.inference: - if stream: - if event_type == EventType.step_start.value: - # TODO: Currently this event is never received - yield ( - event, - LogEvent(role=step_type, content="", end="", color="yellow"), - ) - elif event_type == EventType.step_progress.value: - # HACK: if previous was not step/event was not inference's step_progress - # this is the first time we are getting model inference response - # aka equivalent to step_start for inference. Hence, - # start with "Model>". 
- if ( - previous_event_type != EventType.step_progress.value - and previous_step_type != StepType.inference - ): - yield ( - event, - LogEvent(role=step_type, content="", end="", color="yellow"), - ) - - delta = event.payload.delta - if delta.type == "tool_call": - if delta.parse_status == ToolCallParseStatus.succeeded: - yield ( - event, - LogEvent( - role=None, - content=delta.tool_call, - end="", - color="cyan", - ), - ) - else: - yield ( - event, - LogEvent( - role=None, - content=delta.text, - end="", - color="yellow", - ), - ) - else: - # step_complete - yield event, LogEvent(role=None, content="") - - else: - # Not streaming - if event_type == EventType.step_complete.value: - response = event.payload.step_details.model_response - if response.tool_calls: - content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format) - else: - content = response.content - yield ( - event, - LogEvent( - role=step_type, - content=content, - color="yellow", - ), - ) - - # handle tool_execution - if ( - step_type == StepType.tool_execution - and - # Only print tool calls and responses at the step_complete event - event_type == EventType.step_complete.value - ): - details = event.payload.step_details - for t in details.tool_calls: - yield ( - event, - LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", - ), - ) - for r in details.tool_responses: - yield ( - event, - LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", - ), - ) - - if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value: - details = event.payload.step_details - inserted_context = interleaved_content_as_str(details.inserted_context) - content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}" - - yield ( - event, - LogEvent( - role=step_type, - content=content, - color="cyan", - ), - ) - - previous_event_type = event_type - previous_step_type = step_type diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py index 413c81c5a..0fa5c78ce 100644 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ b/llama_stack/apis/batch_inference/batch_inference.py @@ -6,7 +6,6 @@ from typing import List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.inference import ( @@ -21,6 +20,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py index af5784bbc..91b1ca927 100644 --- a/llama_stack/apis/benchmarks/benchmarks.py +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -5,10 +5,10 @@ # the root directory of this source tree. 
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, webmethod class CommonBenchmarkFields(BaseModel): diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index e648f9a19..0d0afa894 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -7,10 +7,11 @@ from enum import Enum from typing import Annotated, List, Literal, Optional, Union -from llama_models.llama3.api.datatypes import ToolCall -from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, model_validator +from llama_stack.models.llama.datatypes import ToolCall +from llama_stack.schema_utils import json_schema_type, register_schema + @json_schema_type class URL(BaseModel): diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py index 16a5c8ad6..83eea28a2 100644 --- a/llama_stack/apis/common/deployment_types.py +++ b/llama_stack/apis/common/deployment_types.py @@ -7,10 +7,10 @@ from enum import Enum from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel from llama_stack.apis.common.content_types import URL +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index c945bd8ff..bc070017b 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -5,9 +5,10 @@ # the root directory of this source tree. 
from enum import Enum -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class Job(BaseModel): diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py index b4bd1b0c6..d6c6c6919 100644 --- a/llama_stack/apis/common/training_types.py +++ b/llama_stack/apis/common/training_types.py @@ -7,9 +7,10 @@ from datetime import datetime from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class PostTrainingMetric(BaseModel): diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index fa9c5e92e..139ae8875 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -6,10 +6,11 @@ from typing import Literal, Union -from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.schema_utils import json_schema_type, register_schema + @json_schema_type class StringType(BaseModel): diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index 2ad7aab73..d85d22876 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -6,10 +6,10 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.datasets import Dataset +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 5e2b38697..fe9d30e2a 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Literal, Optional, Protocol -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, webmethod class CommonDatasetFields(BaseModel): diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 0751b2c9b..6df93052c 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -6,7 +6,7 @@ from enum import Enum -from llama_models.schema_utils import json_schema_type +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index e5c782150..e2ff4458e 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams +from llama_stack.schema_utils import 
json_schema_type, register_schema, webmethod @json_schema_type diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 9fccd3911..433ba3274 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -17,7 +17,13 @@ from typing import ( runtime_checkable, ) -from llama_models.llama3.api.datatypes import ( +from pydantic import BaseModel, Field, field_validator +from typing_extensions import Annotated + +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent +from llama_stack.apis.models import Model +from llama_stack.apis.telemetry.telemetry import MetricResponseMixin +from llama_stack.models.llama.datatypes import ( BuiltinTool, SamplingParams, StopReason, @@ -25,14 +31,8 @@ from llama_models.llama3.api.datatypes import ( ToolDefinition, ToolPromptFormat, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod -from pydantic import BaseModel, Field, field_validator -from typing_extensions import Annotated - -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.models import Model -from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod class LogProbConfig(BaseModel): diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index cd51469c1..4a647a2d9 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -6,9 +6,10 @@ from typing import List, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type, webmethod + @json_schema_type class ProviderInfo(BaseModel): diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 7e6d9854f..64b9510ea 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,11 +7,11 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class CommonModelFields(BaseModel): diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 8cd2979a8..ed15c6de4 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -8,13 +8,13 @@ from datetime import datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.common.training_types import Checkpoint +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @json_schema_type diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 
513733d1e..fd2f0292c 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -7,12 +7,12 @@ from enum import Enum from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.inference import Message from llama_stack.apis.shields import Shield from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 5bacaaf66..960149476 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -6,10 +6,10 @@ from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams +from llama_stack.schema_utils import json_schema_type, webmethod # mapping of metric to value ScoringResultRow = Dict[str, Any] diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index fece50fbd..52508d2ec 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -16,12 +16,12 @@ from typing import ( runtime_checkable, ) -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod # Perhaps more structure can be imposed on these functions. 
Maybe they could be associated diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ae316ee53..ec1179ac4 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -6,11 +6,11 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class CommonShieldFields(BaseModel): diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index a61fb0cf2..7b41192af 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -7,10 +7,10 @@ from enum import Enum from typing import Any, Dict, List, Optional, Protocol, Union -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.inference import Message +from llama_stack.schema_utils import json_schema_type, webmethod class FilteringFunction(Enum): diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 63ae1dc73..d010a7e3b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -17,11 +17,12 @@ from typing import ( runtime_checkable, ) -from llama_models.llama3.api.datatypes import Primitive -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.models.llama.datatypes import Primitive +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod + # Add this constant near the top of the file, after the imports DEFAULT_TTL_DAYS = 7 diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 2e6b43eb8..cff8eeefe 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -7,12 +7,12 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated, Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @json_schema_type diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 2a407ca00..b83be127f 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -7,13 +7,13 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import 
json_schema_type, webmethod from .rag_tool import RAGToolRuntime diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 1da2c128c..9a4aa322f 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -6,11 +6,11 @@ from typing import List, Literal, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 8feeaa6d4..2bbb3bce8 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -10,12 +10,12 @@ # the root directory of this source tree. from typing import Any, Dict, List, Optional, Protocol, runtime_checkable -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol +from llama_stack.schema_utils import json_schema_type, webmethod class Chunk(BaseModel): diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 3ea534277..6b0463c10 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -16,8 +16,6 @@ from pathlib import Path from typing import Dict, List, Optional import httpx -from llama_models.datatypes import Model -from llama_models.sku_list import LlamaDownloadInfo from pydantic import BaseModel, ConfigDict from rich.console import Console from rich.progress import ( @@ -31,6 +29,8 @@ from rich.progress import ( from termcolor import cprint from llama_stack.cli.subcommand import Subcommand +from llama_stack.models.llama.datatypes import Model +from llama_stack.models.llama.sku_list import LlamaDownloadInfo class Download(Subcommand): @@ -454,7 +454,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): # Handle comma-separated model IDs model_ids = [model_id.strip() for model_id in args.model_id.split(",")] - from llama_models.sku_list import llama_meta_net_info, resolve_model + from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model from .model.safety_models import ( prompt_guard_download_info, diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index 3e55052c5..d8f4e035c 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -7,11 +7,11 @@ import argparse import json -from llama_models.sku_list import resolve_model from termcolor import colored from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table +from llama_stack.models.llama.sku_list import resolve_model class ModelDescribe(Subcommand): diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index 9b5ebb1a5..4fe28751e 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -6,10 +6,9 @@ import argparse -from llama_models.sku_list import all_registered_models - from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table +from llama_stack.models.llama.sku_list import 
all_registered_models class ModelList(Subcommand): diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py index 2e1e1601e..ea9596ba5 100644 --- a/llama_stack/cli/model/prompt_format.py +++ b/llama_stack/cli/model/prompt_format.py @@ -8,9 +8,8 @@ import argparse import textwrap from io import StringIO -from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family - from llama_stack.cli.subcommand import Subcommand +from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family class ModelPromptFormat(Subcommand): diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index 2321c4615..c81783f60 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -6,11 +6,11 @@ from typing import Any, Dict, Optional -from llama_models.datatypes import CheckpointQuantizationFormat -from llama_models.llama3.api.datatypes import SamplingParams -from llama_models.sku_list import LlamaDownloadInfo from pydantic import BaseModel, ConfigDict, Field +from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams +from llama_stack.models.llama.sku_list import LlamaDownloadInfo + class PromptGuardModel(BaseModel): """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed.""" diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index b1d174ede..1925b864f 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -186,33 +186,3 @@ def extract_async_iterator_type(type_hint): inner_args = get_args(arg) return inner_args[0] return None - - -async def example(model: str = None): - from llama_stack.apis.inference import Inference, UserMessage # noqa: F403 - from llama_stack.apis.inference.event_logger import EventLogger - - client_class = create_api_client_class(Inference) - client = client_class("http://localhost:5003") - - if not model: - model = "Llama3.2-3B-Instruct" - - message = UserMessage(content="hello world, write me a 2 sentence poem about the moon") - cprint(f"User>{message.content}", "green") - - stream = True - iterator = await client.chat_completion( - model=model, - messages=[message], - stream=stream, - ) - - async for log in EventLogger().log(iterator): - log.print() - - -if __name__ == "__main__": - import asyncio - - asyncio.run(example()) diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py new file mode 100644 index 000000000..a5dc9ac4a --- /dev/null +++ b/llama_stack/models/llama/datatypes.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +from enum import Enum +from typing import Any, Dict, Literal, Optional, Union + +# import all for backwards compatibility +from llama_models.datatypes import * # noqa: F403 +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Annotated + +from llama_stack.schema_utils import json_schema_type, register_schema + +register_schema(ToolCall) + + +@json_schema_type +class ToolParamDefinition(BaseModel): + param_type: str + description: Optional[str] = None + required: Optional[bool] = True + default: Optional[Any] = None + + +@json_schema_type +class ToolDefinition(BaseModel): + tool_name: Union[BuiltinTool, str] + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("tool_name", mode="before") + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinTool(v) + except ValueError: + return v + return v + + +@json_schema_type +class GreedySamplingStrategy(BaseModel): + type: Literal["greedy"] = "greedy" + + +@json_schema_type +class TopPSamplingStrategy(BaseModel): + type: Literal["top_p"] = "top_p" + temperature: Optional[float] = Field(..., gt=0.0) + top_p: Optional[float] = 0.95 + + +@json_schema_type +class TopKSamplingStrategy(BaseModel): + type: Literal["top_k"] = "top_k" + top_k: int = Field(..., ge=1) + + +SamplingStrategy = register_schema( + Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), + ], + name="SamplingStrategy", +) + + +@json_schema_type +class SamplingParams(BaseModel): + strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) + + max_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + + +class CheckpointQuantizationFormat(Enum): + # default format + bf16 = "bf16" + + # used for enabling fp8_rowwise inference, some weights are bf16 + fp8_mixed = "fp8-mixed" + + int8 = "int8" + + int4 = "int4" + + +class ModelFamily(Enum): + llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_3 = "llama3_3" + safety = "safety" + + +class CoreModelId(Enum): + """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" + + # Llama 2 family + llama2_7b = "Llama-2-7b" + llama2_13b = "Llama-2-13b" + llama2_70b = "Llama-2-70b" + llama2_7b_chat = "Llama-2-7b-chat" + llama2_13b_chat = "Llama-2-13b-chat" + llama2_70b_chat = "Llama-2-70b-chat" + + # Llama 3 family + llama3_8b = "Llama-3-8B" + llama3_70b = "Llama-3-70B" + llama3_8b_instruct = "Llama-3-8B-Instruct" + llama3_70b_instruct = "Llama-3-70B-Instruct" + + # Llama 3.1 family + llama3_1_8b = "Llama3.1-8B" + llama3_1_70b = "Llama3.1-70B" + llama3_1_405b = "Llama3.1-405B" + llama3_1_8b_instruct = "Llama3.1-8B-Instruct" + llama3_1_70b_instruct = "Llama3.1-70B-Instruct" + llama3_1_405b_instruct = "Llama3.1-405B-Instruct" + + # Llama 3.2 family + llama3_2_1b = "Llama3.2-1B" + llama3_2_3b = "Llama3.2-3B" + llama3_2_1b_instruct = "Llama3.2-1B-Instruct" + llama3_2_3b_instruct = "Llama3.2-3B-Instruct" + llama3_2_11b_vision = "Llama3.2-11B-Vision" + llama3_2_90b_vision = "Llama3.2-90B-Vision" + llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" + llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" + + # Llama 3.3 family + llama3_3_70b_instruct = "Llama3.3-70B-Instruct" + + # Safety models + llama_guard_3_8b = "Llama-Guard-3-8B" + llama_guard_2_8b = "Llama-Guard-2-8B" + llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" + llama_guard_3_1b = "Llama-Guard-3-1B" + + +def is_multimodal(model_id) -> bool: + if model_id in [ + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return True + else: + return False + + +def model_family(model_id) -> ModelFamily: + if model_id in [ + CoreModelId.llama2_7b, + CoreModelId.llama2_13b, + CoreModelId.llama2_70b, + CoreModelId.llama2_7b_chat, + CoreModelId.llama2_13b_chat, + CoreModelId.llama2_70b_chat, + ]: + return ModelFamily.llama2 + elif model_id in [ + CoreModelId.llama3_8b, + CoreModelId.llama3_70b, + CoreModelId.llama3_8b_instruct, + CoreModelId.llama3_70b_instruct, + ]: + return ModelFamily.llama3 + elif model_id in [ + CoreModelId.llama3_1_8b, + CoreModelId.llama3_1_70b, + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_405b_instruct, + ]: + return ModelFamily.llama3_1 + elif model_id in [ + CoreModelId.llama3_2_1b, + CoreModelId.llama3_2_3b, + CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return ModelFamily.llama3_2 + elif model_id in [ + CoreModelId.llama3_3_70b_instruct, + ]: + return ModelFamily.llama3_3 + elif model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_2_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return ModelFamily.safety + else: + raise ValueError(f"Unknown model family for {model_id}") + + +class Model(BaseModel): + core_model_id: CoreModelId + description: str + huggingface_repo: Optional[str] = None + recommended_sampling_params: Optional[SamplingParams] = None + arch_args: Dict[str, Any] + variant: str = "" + + quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 + pth_file_count: int + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + # silence pydantic until we remove the `model_` fields + model_config = 
ConfigDict(protected_namespaces=()) + + @property + def model_family(self) -> ModelFamily: + return model_family(self.core_model_id) + + # The SKU is uniquely identified by (model_id, variant) combo + def descriptor(self, shorten_default_variant: bool = True) -> str: + if not self.variant: + return self.core_model_id.value + return f"{self.core_model_id.value}:{self.variant}" + + @property + def is_instruct_model(self) -> bool: + # instruct SKUs carry the "instruct" suffix in their enum member names + return "instruct" in self.core_model_id.name + + # Featured models are shown in the non-exhaustive model list + @property + def is_featured(self) -> bool: + return self.model_family in [ + ModelFamily.llama3_1, + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.safety, + ] + + @property + def max_seq_length(self) -> int: + if self.model_family == ModelFamily.llama2: + return 4096 + elif self.core_model_id == CoreModelId.llama_guard_2_8b: + return 4096 + elif self.model_family == ModelFamily.llama3: + return 8192 + elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: + return 131072 + elif self.model_family == ModelFamily.llama3_2: + if self.quantization_format == CheckpointQuantizationFormat.int4: + return 8192 + return 131072 + elif self.core_model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return 131072 + else: + raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") diff --git a/llama_stack/models/llama/llama3/dog.jpg b/llama_stack/models/llama/llama3/dog.jpg new file mode 100644 index 000000000..f9a3a8057 Binary files /dev/null and b/llama_stack/models/llama/llama3/dog.jpg differ diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py new file mode 100644 index 000000000..bc42228a5 --- /dev/null +++ b/llama_stack/models/llama/llama3/interface.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from pathlib import Path +from typing import List, Optional + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer +from termcolor import colored + +from llama_stack.models.llama.datatypes import ToolDefinition + +from . 
import template_data +from .prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + SystemDefaultGenerator, + ToolResponseGenerator, +) + +THIS_DIR = Path(__file__).parent + + +class Template: + def __init__( + self, + role, + template_name, + data_provider=None, + notes=None, + ): + self.role = role + self.template_name = template_name + self.data_provider = data_provider or "" + self._notes = notes or "" + + @property + def notes(self): + default = "↵ represents newline" + notes = default + if self._notes: + notes += "\n" + notes += self._notes + return notes + + +TEMPLATES = [ + Template( + "user", + "user-default", + "user_default", + ), + Template( + "user", + "user-images", + "user_images", + ), + Template("user", "user-interleaved-images", "user_interleaved_images"), + Template( + "assistant", + "assistant-builtin-tool-call", + "assistant_builtin_tool_call", + "Notice <|python_tag|>", + ), + Template( + "assistant", + "assistant-custom-tool-call", + "assistant_custom_tool_call", + "Notice <function=...> format", + ), + Template( + "assistant", + "assistant-default", + "assistant_default", + ), + Template( + "system", + "system-builtin-and-custom-tools", + "system_message_builtin_and_custom_tools", + ), + Template( + "system", + "system-builtin-tools-only", + "system_message_builtin_tools_only", + ), + Template( + "system", + "system-custom-tools-only", + "system_message_custom_tools_only", + ), + Template( + "system", + "system-default", + "system_default", + ), + Template( + "tool", + "tool-success", + "tool_success", + "Note ipython header and [stdout]", + ), + Template( + "tool", + "tool-failure", + "tool_failure", + "Note ipython header and [stderr]", + ), +] + + +class LLama31Interface: + def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json): + self.tokenizer = Tokenizer.get_instance() + self.formatter = ChatFormat(self.tokenizer) + self.tool_prompt_format = tool_prompt_format + + def get_tokens(self, messages: List[RawMessage]) -> List[int]: + model_input = self.formatter.encode_dialog_prompt( + messages, + self.tool_prompt_format, + ) + return model_input.tokens + + def tool_response_messages(self, *args, **kwargs): + template = ToolResponseGenerator().gen(*args, **kwargs) + return [ + RawMessage( + role="tool", + content=template.render(), + ) + ] + + def system_messages( + self, + builtin_tools: List[BuiltinTool], + custom_tools: List[ToolDefinition], + instruction: Optional[str] = None, + ) -> List[RawMessage]: + messages = [] + + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + + sys_content = "" + + tool_template = None + if builtin_tools or custom_tools: + tool_gen = BuiltinToolGenerator() + tool_template = tool_gen.gen(builtin_tools + custom_tools) + + sys_content += tool_template.render() + sys_content += "\n" + + sys_content += default_template.render() + + if instruction: + sys_content += "\n\n" + sys_content += instruction + + sys_content += "\n" + messages.append(RawMessage(role="system", content=sys_content)) + + if custom_tools: + if self.tool_prompt_format == ToolPromptFormat.json: + tool_gen = JsonCustomToolGenerator() + elif self.tool_prompt_format == ToolPromptFormat.function_tag: + tool_gen = FunctionTagCustomToolGenerator() + else: + raise ValueError(f"Unsupported ToolPromptFormat {self.tool_prompt_format}") + + custom_template = tool_gen.gen(custom_tools) + messages.append(RawMessage(role="user", content=custom_template.render())) + + return messages + + 
def assistant_response_messages( + self, + content: str, + stop_reason: StopReason, + tool_call: Optional[ToolCall] = None, + ) -> List[RawMessage]: + tool_calls = [] + if tool_call: + tool_calls.append(tool_call) + return [ + RawMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + ) + ] + + def user_message(self, content: str) -> List[RawMessage]: + return [RawMessage(role="user", content=content)] + + def display_message_as_tokens(self, message: RawMessage) -> None: + """Util to print tokenized string to shell""" + tokens = self.formatter.encode_message(message, self.tool_prompt_format) + on_colors = [ + "on_red", + "on_green", + "on_yellow", + "on_blue", + "on_magenta", + "on_cyan", + ] + for i, t in enumerate(tokens): + on_col = on_colors[i % len(on_colors)] + print(colored(self.tokenizer.decode([t]), "white", on_col), end="") + print("\n", end="") + + +def list_jinja_templates() -> List[Template]: + return TEMPLATES + + +def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat): + by_name = {t.template_name: t for t in TEMPLATES} + if name not in by_name: + raise ValueError(f"No template found for `{name}`") + + template = by_name[name] + interface = LLama31Interface(tool_prompt_format) + + data_func = getattr(template_data, template.data_provider) + if template.role == "system": + messages = interface.system_messages(**data_func()) + elif template.role == "tool": + messages = interface.tool_response_messages(**data_func()) + elif template.role == "assistant": + messages = interface.assistant_response_messages(**data_func()) + elif template.role == "user": + messages = interface.user_message(**data_func()) + + tokens = interface.get_tokens(messages) + special_tokens = list(interface.tokenizer.special_tokens.values()) + tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens] + return template, tokens diff --git a/llama_stack/models/llama/llama3/pasta.jpeg b/llama_stack/models/llama/llama3/pasta.jpeg new file mode 100644 index 000000000..e8299321c Binary files /dev/null and b/llama_stack/models/llama/llama3/pasta.jpeg differ diff --git a/llama_stack/models/llama/llama3/prompt_templates/__init__.py b/llama_stack/models/llama/llama3/prompt_templates/__init__.py new file mode 100644 index 000000000..4eed54d12 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +from .base import PromptTemplate, PromptTemplateGeneratorBase # noqa: F401 +from .system_prompts import ( # noqa: F401 + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) +from .tool_response import ToolResponseGenerator # noqa: F401 diff --git a/llama_stack/models/llama/llama3/prompt_templates/base.py b/llama_stack/models/llama/llama3/prompt_templates/base.py new file mode 100644 index 000000000..bff2a21e1 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/base.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from typing import Any, Dict, List + +from jinja2 import Template + + +@dataclass +class PromptTemplate: + template: str + data: Dict[str, Any] + + def render(self): + template = Template(self.template) + return template.render(self.data) + + +class PromptTemplateGeneratorBase: + """ + Base class for prompt template generators. + """ + + def gen(self, *args, **kwargs) -> PromptTemplate: + raise NotImplementedError() + + def data_examples(self) -> List[Any]: + raise NotImplementedError() diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py new file mode 100644 index 000000000..27b1a3502 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +import textwrap +from datetime import datetime +from typing import Any, List, Optional + +from llama_models.datatypes import ( + BuiltinTool, +) + +from llama_stack.models.llama.datatypes import ( + ToolDefinition, + ToolParamDefinition, +) + +from .base import PromptTemplate, PromptTemplateGeneratorBase + + +class SystemDefaultGenerator(PromptTemplateGeneratorBase): + def gen(self, *args, **kwargs) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Cutting Knowledge Date: December 2023 + Today Date: {{ today }} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"today": datetime.now().strftime("%d %B %Y")}, + ) + + def data_examples(self) -> List[Any]: + return [None] + + +class BuiltinToolGenerator(PromptTemplateGeneratorBase): + def _tool_breakdown(self, tools: List[ToolDefinition]): + builtin_tools, custom_tools = [], [] + for dfn in tools: + if isinstance(dfn.tool_name, BuiltinTool): + builtin_tools.append(dfn) + else: + custom_tools.append(dfn) + + return builtin_tools, custom_tools + + def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: + builtin_tools, custom_tools = self._tool_breakdown(tools) + template_str = textwrap.dedent( + """ + {% if builtin_tools or custom_tools -%} + Environment: ipython + {% endif -%} + {% set builtin_tools = builtin_tools | reject('equalto', 'code_interpreter') | list -%} + {% if builtin_tools -%} + Tools: {{ builtin_tools | join(", ") | trim -}} + {% endif %} + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + { + "builtin_tools": [t.tool_name.value for t in builtin_tools], + "custom_tools": custom_tools, + }, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + # builtin tools + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition(tool_name=BuiltinTool.wolfram_alpha), + ], + # only code interpreter + [ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ] + + +class JsonCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Answer the user's question by making use of the following functions if needed. + If none of the function can be used, please say so. + Here is a list of functions in JSON format: + {% for t in custom_tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "type": "function", + "function": { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "object", + "properties": [ + {%- for name, param in tparams.items() %} + { + "{{name}}": { + "type": "object", + "description": "{{param.description}}" + } + }{% if not loop.last %},{% endif %} + {%- endfor %} + ], + "required": {{ required_params | tojson }} + } + } + } + {% endfor %} + Return function calls in JSON format. 
+ """ + ) + + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): + def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + You have access to the following functions: + + {% for t in custom_tools %} + {#- manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set modified_params = t.parameters.copy() -%} + {%- for key, value in modified_params.items() -%} + {%- if 'default' in value -%} + {%- set _ = value.pop('default', None) -%} + {%- endif -%} + {%- endfor -%} + {%- set tparams = modified_params | tojson -%} + Use the function '{{ tname }}' to '{{ tdesc }}': + {"name": "{{tname}}", "description": "{{tdesc}}", "parameters": {{tparams}}} + + {% endfor -%} + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + """ + ) + return PromptTemplate( + template_str.lstrip("\n"), + {"custom_tools": [t.model_dump() for t in custom_tools]}, + ) + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="trending_songs", + description="Returns the trending songs on a Music site", + parameters={ + "n": ToolParamDefinition( + param_type="int", + description="The number of songs to return", + required=True, + ), + "genre": ToolParamDefinition( + param_type="str", + description="The genre of the songs to return", + required=False, + ), + }, + ), + ] + ] + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are an expert in composing functions. You are given a question and a set of possible functions. + Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + If none of the function can be used, point it out. If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. 
+ + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/models/llama/llama3/prompt_templates/tool_response.py b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py new file mode 100644 index 000000000..3df4dac14 --- /dev/null +++ b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+
+import textwrap
+from typing import Optional
+
+from .base import PromptTemplate, PromptTemplateGeneratorBase
+
+
+class ToolResponseGenerator(PromptTemplateGeneratorBase):
+    def gen(
+        self,
+        status: str,
+        stdout: Optional[str] = None,
+        stderr: Optional[str] = None,
+    ):
+        assert status in [
+            "success",
+            "failure",
+        ], f"status must be 'success' or 'failure'; Got: {status}"
+        template_str = textwrap.dedent(
+            """
+            {% if status == "success" %}completed{% else %}failed{% endif %}
+            {%- if stdout %}
+            [stdout]{{ stdout }}[/stdout]
+            {%- endif -%}
+            {%- if stderr %}
+            [stderr]{{ stderr }}[/stderr]
+            {%- endif -%}
+            """
+        )
+        return PromptTemplate(
+            template_str.lstrip("\n"),
+            {
+                "status": status,
+                "stdout": stdout,
+                "stderr": stderr,
+            },
+        )
+
+    def data_examples(self):
+        return [
+            # success
+            {
+                "status": "success",
+                "stdout": '{"results":["something something"]}',
+            },
+            # failure
+            {
+                "status": "failure",
+                "stderr": "brave_search encountered an error: could not communicate with api.brave.com",
+            },
+        ]
diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py
new file mode 100644
index 000000000..620816ffc
--- /dev/null
+++ b/llama_stack/models/llama/llama3/template_data.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+from llama_models.datatypes import (
+    BuiltinTool,
+    StopReason,
+    ToolCall,
+)
+
+from .prompt_templates import (
+    BuiltinToolGenerator,
+    JsonCustomToolGenerator,
+    ToolResponseGenerator,
+)
+
+INSTRUCTION = "You are a helpful assistant."
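As a quick illustration (a sketch, not part of the patch), the `ToolResponseGenerator` in `tool_response.py` above renders tool output back to the model like so; the import path follows the `prompt_templates` package import used in this file:

```python
from llama_stack.models.llama.llama3.prompt_templates import ToolResponseGenerator

gen = ToolResponseGenerator()

# Renders "completed" followed by an [stdout]...[/stdout] block.
print(gen.gen(status="success", stdout='{"results": ["..."]}').render())

# Renders "failed" followed by an [stderr]...[/stderr] block.
print(gen.gen(status="failure", stderr="could not reach api.brave.com").render())
```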
+ + +def system_message_builtin_tools_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_code_only(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[1], + "custom_tools": [], + "instruction": "", + } + + +def system_message_custom_tools_only(): + return { + "builtin_tools": [], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_message_builtin_and_custom_tools(): + return { + "builtin_tools": BuiltinToolGenerator().data_examples()[0], + "custom_tools": JsonCustomToolGenerator().data_examples()[0], + "instruction": INSTRUCTION, + } + + +def system_default(): + return { + "builtin_tools": [], + "custom_tools": [], + "instruction": INSTRUCTION, + } + + +def tool_success(): + return ToolResponseGenerator().data_examples()[0] + + +def tool_failure(): + return ToolResponseGenerator().data_examples()[1] + + +def assistant_builtin_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name=BuiltinTool.brave_search, + arguments={ + "query": "Who won NBA in 2024?", + }, + ), + "stop_reason": StopReason.end_of_message, + } + + +def assistant_custom_tool_call(): + return { + "content": "", + "tool_call": ToolCall( + call_id="uuid", + tool_name="trending_songs", + arguments={"country": "US", "n": 10}, + ), + "stop_reason": StopReason.end_of_turn, + } + + +def assistant_default(): + return { + "content": "Hi, I am a helpful assistant. What can I help you with today?", + "tool_call": None, + "stop_reason": StopReason.end_of_turn, + } + + +def user_default(): + return {"content": "Please tell me how to plan a trip to New York"} + + +def user_images(): + return {"content": "<|image|><|image|>What do these images depict?"} + + +def user_interleaved_images(): + return {"content": "<|image|>Describe the image in one sentence.<|image|>Write a haiku about these images"} diff --git a/llama_stack/models/llama/llama3/test_system_prompts.py b/llama_stack/models/llama/llama3/test_system_prompts.py new file mode 100644 index 000000000..b47b1ff2d --- /dev/null +++ b/llama_stack/models/llama/llama3/test_system_prompts.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
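The `system_message_*` helpers above are keyword-argument bundles rather than finished messages; `prompt_format.py` later in this patch splices them into `LLama31Interface`. A minimal sketch of that wiring (assuming the interface is constructed with a `ToolPromptFormat`, as `prompt_format.py` does):

```python
from llama_models.datatypes import ToolPromptFormat

from llama_stack.models.llama.llama3.interface import LLama31Interface
from llama_stack.models.llama.llama3.template_data import (
    system_message_builtin_tools_only,
)

interface = LLama31Interface(ToolPromptFormat.json)

# The dict expands into the builtin_tools/custom_tools/instruction kwargs.
messages = interface.system_messages(**system_message_builtin_tools_only())
messages += interface.user_message(content="Search the web for the latest price of 1oz gold?")
```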
+
+import textwrap
+import unittest
+from datetime import datetime
+
+from .prompt_templates import (
+    BuiltinToolGenerator,
+    FunctionTagCustomToolGenerator,
+    JsonCustomToolGenerator,
+    PythonListCustomToolGenerator,
+    SystemDefaultGenerator,
+)
+
+
+class PromptTemplateTests(unittest.TestCase):
+    def check_generator_output(self, generator, expected_text):
+        example = generator.data_examples()[0]
+
+        pt = generator.gen(example)
+        text = pt.render()
+        # print(text)  # debugging
+        assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
+
+    def test_system_default(self):
+        generator = SystemDefaultGenerator()
+        today = datetime.now().strftime("%d %B %Y")
+        expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
+        self.check_generator_output(generator, expected_text)
+
+    def test_system_builtin_only(self):
+        generator = BuiltinToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            Environment: ipython
+            Tools: brave_search, wolfram_alpha
+            """
+        )
+        self.check_generator_output(generator, expected_text.strip("\n"))
+
+    def test_system_custom_only(self):
+        self.maxDiff = None
+        generator = JsonCustomToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            Answer the user's question by making use of the following functions if needed.
+            If none of the function can be used, please say so.
+            Here is a list of functions in JSON format:
+            {
+                "type": "function",
+                "function": {
+                    "name": "trending_songs",
+                    "description": "Returns the trending songs on a Music site",
+                    "parameters": {
+                        "type": "object",
+                        "properties": [
+                            {
+                                "n": {
+                                    "type": "object",
+                                    "description": "The number of songs to return"
+                                }
+                            },
+                            {
+                                "genre": {
+                                    "type": "object",
+                                    "description": "The genre of the songs to return"
+                                }
+                            }
+                        ],
+                        "required": ["n"]
+                    }
+                }
+            }
+
+            Return function calls in JSON format.
+            """
+        )
+        self.check_generator_output(generator, expected_text.strip("\n"))
+
+    def test_system_custom_function_tag(self):
+        self.maxDiff = None
+        generator = FunctionTagCustomToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            You have access to the following functions:
+
+            Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
+            {"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
+
+            Think very carefully before calling functions.
+            If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+            <function=example_function_name>{"example_name": "example_value"}</function>
+
+            Reminder:
+            - If looking for real time information use relevant functions before falling back to brave_search
+            - Function calls MUST follow the specified format, start with <function= and end with </function>
+            - Required parameters MUST be specified
+            - Only call one function at a time
+            - Put the entire function call reply on one line
+            """
+        )
+        self.check_generator_output(generator, expected_text.strip("\n"))
+
+    def test_llama_3_2_system_zero_shot(self):
+        generator = PythonListCustomToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            You are an expert in composing functions. You are given a question and a set of possible functions.
+            Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+            If none of the function can be used, point it out.
If the given question lacks the parameters required by the function, + also point it out. You should only return the function call in tools call sections. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ] + """ + ) + self.check_generator_output(generator, expected_text.strip("\n")) + + def test_llama_3_2_provided_system_prompt(self): + generator = PythonListCustomToolGenerator() + expected_text = textwrap.dedent( + """ + Overriding message. + + If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] + You SHOULD NOT include any other text in the response. + + Here is a list of functions in JSON format that you can invoke. + + [ + { + "name": "get_weather", + "description": "Get weather info for places", + "parameters": { + "type": "dict", + "required": ["city"], + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get the weather for" + }, + "metric": { + "type": "string", + "description": "The metric for weather. Options are: celsius, fahrenheit", + "default": "celsius" + } + } + } + } + ]""" + ) + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}" diff --git a/llama_stack/models/llama/llama3_1/__init__.py b/llama_stack/models/llama/llama3_1/__init__.py new file mode 100644 index 000000000..38ee47d66 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py new file mode 100644 index 000000000..edbce3bc0 --- /dev/null +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool. + - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios: + - at the end of a direct interaction between the model and the user + - at the end of multiple interactions between the model and any available tools + This token signals to the executor that the model has finished generating a response. + - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call. + """ + ), + textwrap.dedent( + """ + There are 4 different roles that are supported by Llama 3.1 + - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively. + - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model. + - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".) + - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts. 
+            """
+        ),
+        UseCase(
+            title="Llama 3.1 Base Model",
+            description="Text completion for Llama 3.1 base model uses this format.",
+            dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
+            notes="Note start special tag",
+        ),
+        "## Llama 3.1 Instruct Model",
+        UseCase(
+            title="User and assistant conversation",
+            description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
+            dialogs=[
+                [
+                    RawMessage(role="system", content="You are a helpful assistant"),
+                    RawMessage(
+                        role="user",
+                        content="Answer who are you in the form of jeopardy?",
+                    ),
+                ]
+            ],
+            notes="",
+        ),
+        "## Tool Calling Formats",
+        textwrap.dedent(
+            """
+            The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
+            - Brave Search: Tool call to perform web searches.
+            - Wolfram Alpha: Tool call to perform complex mathematical calculations.
+            - Code Interpreter: Enables the model to output python code.
+            """
+        ),
+        UseCase(
+            title="Builtin Tool Calling",
+            description=textwrap.dedent(
+                """
+                Here is an example of a conversation using brave search
+                """
+            ),
+            dialogs=[llama3_1_builtin_tool_call_dialog()],
+            notes=textwrap.dedent(
+                """
+                - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
+                - The message body of the assistant response starts with a special tag <|python_tag|>
+                - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
+                - The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
+                """
+            ),
+        ),
+        UseCase(
+            title="Builtin Code Interpreter",
+            description="Here is an actual example of model responding with code",
+            dialogs=[
+                [
+                    RawMessage(role="system", content="Environment: ipython"),
+                    RawMessage(
+                        role="user",
+                        content="Write code to check if number is prime, use that to see if the number 7 is prime",
+                    ),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - Model starts with <|python_tag|> and continues writing python code that it needs to be executed
+                - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
+                """
+            ),
+        ),
+        UseCase(
+            title="Built-in tools full interaction",
+            description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
+            dialogs=[
+                [
+                    RawMessage(
+                        role="system",
+                        content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
+                    ),
+                    RawMessage(role="user", content="What is the 100th decimal of pi?"),
+                    RawMessage(
+                        role="assistant",
+                        content="",
+                        stop_reason=StopReason.end_of_message,
+                        tool_calls=[
+                            ToolCall(
+                                call_id="tool_call_id",
+                                tool_name=BuiltinTool.wolfram_alpha,
+                                arguments={"query": "100th decimal of pi"},
+                            )
+                        ],
+                    ),
+                    RawMessage(
+                        role="tool",
+                        content=wolfram_alpha_response(),
+                    ),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - Note the `<|python_tag|>` in the assistant response.
+                - Role is `tool` for the wolfram alpha response that is passed back to the model.
+                - Final message from assistant has <|eot_id|> tag.
+                """
+            ),
+        ),
+        "## Zero shot tool calling",
+        UseCase(
+            title="JSON based tool calling",
+            description=textwrap.dedent(
+                """
+                Llama models can now output custom tool calls from a single message to allow easier tool calling.
+                The following prompts provide an example of how custom tools can be called from the output of the model.
+                It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
+                """
+            ),
+            dialogs=[llama3_1_custom_tool_call_dialog()],
+            notes=textwrap.dedent(
+                """
+                - JSON format for providing tools needs name, description and parameters
+                - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
+                - Instructions for tools added as a user message
+                - Only single tool calls are supported as of now
+                """
+            ),
+        ),
+        # FIXME: This is not working yet as expected
+        # UseCase(
+        #     title="E2E tool call example",
+        #     description=textwrap.dedent(
+        #         """
+        #         Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
+        #         """
+        #     ),
+        #     dialogs=[
+        #         llama3_1_e2e_tool_call_dialog(
+        #             tool_prompt_format=ToolPromptFormat.function_tag
+        #         )
+        #     ],
+        #     notes="",
+        # ),
+        "## Example of a user defined tool calling",
+        UseCase(
+            title="`<function>` based tool calling",
+            description=textwrap.dedent(
+                """
+                Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
+                In this example, we define a custom tool calling format using the `<function>` tag.
+                """
+            ),
+            dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)],
+            notes=textwrap.dedent(
+                """
+                - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
+                - Instructions for tools added as a user message
+                """
+            ),
+        ),
+    ]
diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py
new file mode 100644
index 000000000..38ee47d66
--- /dev/null
+++ b/llama_stack/models/llama/llama3_2/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py
new file mode 100644
index 000000000..29557f4be
--- /dev/null
+++ b/llama_stack/models/llama/llama3_2/prompts_text.py
@@ -0,0 +1,235 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
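Before the Llama 3.2 prompt files below, note how these `usecases()` lists are consumed: `prompt_format.py` later in this patch gives `UseCase` a `to_text()` method that runs each dialog through a generator and emits markdown. A hypothetical driver sketch (the `generator` object must expose the `chat_completion_raw`/`text_completion_raw`/`tokenizer`/`formatter` attributes that `to_text()` relies on; this is not part of the patch):

```python
from llama_stack.models.llama.llama3_1.prompts import usecases


def render_prompt_format_docs(generator) -> str:
    out = []
    for uc in usecases():
        # Plain strings are markdown passed through verbatim; UseCase
        # objects render their dialogs into prompt/response code blocks.
        out.append(uc if isinstance(uc, str) else uc.to_text(generator))
    return "\n".join(out)
```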
+import json
+import textwrap
+
+from llama_models.datatypes import (
+    RawMessage,
+    StopReason,
+    ToolCall,
+    ToolPromptFormat,
+)
+
+from ..prompt_format import (
+    TextCompletionContent,
+    UseCase,
+    llama3_1_builtin_code_interpreter_dialog,
+)
+
+
+def user_tool_call():
+    content = textwrap.dedent(
+        """
+        Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
+        Here is a list of functions in JSON format that you can invoke:
+        [
+            {
+                "name": "get_user_info",
+                "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
+                "parameters": {
+                    "type": "dict",
+                    "required": [
+                        "user_id"
+                    ],
+                    "properties": {
+                        "user_id": {
+                            "type": "integer",
+                            "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
+                        },
+                        "special": {
+                            "type": "string",
+                            "description": "Any special information or parameters that need to be considered while fetching user details.",
+                            "default": "none"
+                        }
+                    }
+                }
+            }
+        ]
+
+        Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
+
+        NO other text MUST be included.
+        """
+    )
+    return content.strip()
+
+
+def system_tool_call():
+    content = textwrap.dedent(
+        """
+        You are an expert in composing functions. You are given a question and a set of possible functions.
+        Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+        If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+        also point it out. You should only return the function call in tools call sections.
+
+        If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+        You SHOULD NOT include any other text in the response.
+
+        Here is a list of functions in JSON format that you can invoke.
+
+        [
+            {
+                "name": "get_weather",
+                "description": "Get weather info for places",
+                "parameters": {
+                    "type": "dict",
+                    "required": [
+                        "city"
+                    ],
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The name of the city to get the weather for"
+                        },
+                        "metric": {
+                            "type": "string",
+                            "description": "The metric for weather. Options are: celsius, fahrenheit",
+                            "default": "celsius"
+                        }
+                    }
+                }
+            }
+        ]
+        """
+    )
+    return content.strip()
+
+
+def usecases():
+    return [
+        UseCase(
+            title="User and assistant conversation",
+            description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
+            dialogs=[
+                [
+                    RawMessage(role="system", content="You are a helpful assistant"),
+                    RawMessage(role="user", content="Who are you?"),
+                ]
+            ],
+            notes="This format is unchanged from Llama3.1",
+        ),
+        UseCase(
+            title="Zero shot function calling",
+            description=textwrap.dedent(
+                """
+                For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
+                This new format is designed to be more flexible and powerful than the previous format.
+                All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
+                It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
+                Here is an example for the same,
+                """
+            ),
+            dialogs=[
+                # Zero shot tool calls as system message
+                [
+                    RawMessage(role="system", content=system_tool_call()),
+                    RawMessage(role="user", content="What is the weather in SF and Seattle?"),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - The output supports multiple tool calls natively
+                - JSON format for defining the functions in the system prompt is similar to Llama3.1
+                """
+            ),
+        ),
+        UseCase(
+            title="Zero shot function calling with user message",
+            description=textwrap.dedent(
+                """
+                While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
+                """
+            ),
+            dialogs=[
+                # Zero shot tool call as user message
+                [
+                    RawMessage(role="user", content=user_tool_call()),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - The tool call format for the model is the same whether your function calls are provided in the system or user message.
+                - While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
+                """
+            ),
+        ),
+        UseCase(
+            title="Code Interpreter",
+            description=textwrap.dedent(
+                """
+                Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
+                Here is an example,
+                """
+            ),
+            dialogs=[llama3_1_builtin_code_interpreter_dialog()],
+            notes=textwrap.dedent(
+                """
+                - Note `Environment: ipython` in the system prompt.
+                - Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
+                """
+            ),
+        ),
+        UseCase(
+            title="Zero shot function calling E2E format",
+            description=textwrap.dedent(
+                """
+                Here is an example of the e2e cycle of tool calls with the model in a multi-step way.
+                """
+            ),
+            dialogs=[
+                [
+                    RawMessage(role="system", content=system_tool_call()),
+                    RawMessage(role="user", content="What is the weather in SF?"),
+                    RawMessage(
+                        role="assistant",
+                        content="",
+                        stop_reason=StopReason.end_of_turn,
+                        tool_calls=[
+                            ToolCall(
+                                call_id="cc",
+                                tool_name="get_weather",
+                                arguments={
+                                    "city": "San Francisco",
+                                    "metric": "celsius",
+                                },
+                            )
+                        ],
+                    ),
+                    RawMessage(
+                        role="tool",
+                        content=json.dumps("25 C"),
+                    ),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - The output of the function call is provided back to the model as a tool response (in json format).
+                - Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
+                - The model finally summarizes the information from the tool response and returns the result to the user.
+                """
+            ),
+            tool_prompt_format=ToolPromptFormat.python_list,
+        ),
+        UseCase(
+            title="Prompt format for base models",
+            description=textwrap.dedent(
+                """
+                For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
+                """
+            ),
+            dialogs=[
+                TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"),
+            ],
+            notes="Same as Llama3.1",
+        ),
+    ]
diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py
new file mode 100644
index 000000000..c3cfe5e7b
--- /dev/null
+++ b/llama_stack/models/llama/llama3_2/prompts_vision.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+import textwrap
+from pathlib import Path
+
+from llama_models.datatypes import (
+    RawMediaItem,
+    RawMessage,
+    RawTextItem,
+)
+
+from ..prompt_format import (
+    TextCompletionContent,
+    UseCase,
+    llama3_1_builtin_tool_call_dialog,
+    # llama3_1_builtin_tool_call_with_image_dialog,
+    llama3_2_user_assistant_conversation,
+)
+
+
+def usecases():
+    this_dir = Path(__file__).parent.parent.resolve()
+    with open(this_dir / "scripts/resources/dog.jpg", "rb") as f:
+        img = f.read()
+
+    return [
+        llama3_2_user_assistant_conversation(),
+        UseCase(
+            title="User and assistant conversation with Images",
+            description="This example shows how to pass an image to the model as part of the messages.",
+            dialogs=[
+                [
+                    RawMessage(
+                        role="user",
+                        content=[
+                            RawMediaItem(data=img),
+                            RawTextItem(text="Describe this image in two sentences"),
+                        ],
+                    )
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - The `<|image|>` tag is used to indicate presence of the image
+                - The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
+                ![Image](mm-model.png)
+                - It's important to position the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens
+                - The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
+                - We recommend using a single image in one prompt
+                """
+            ),
+        ),
+        UseCase(
+            title="Builtin and Zero Shot Tool Calling",
+            description=textwrap.dedent(
+                """
+                Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
+                Use `Environment: ipython` to enable tools.
+                Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
+                The same builtin tools as Llama3.1 are available,
+                - code_interpreter (for executing python code)
+                - brave_search (to search the web)
+                - wolfram_alpha (for querying wolfram alpha for mathematical questions)
+                """,
+            ),
+            dialogs=[llama3_1_builtin_tool_call_dialog()],
+            notes=textwrap.dedent(
+                """
+                - Note the `<|python_tag|>` before `brave_search` function call.
+                - The `<|eom_id|>` tag is used to indicate the end of the message.
+                - Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
+                - Tool Calling does NOT work with images in the prompt as of now.
+                """
+            ),
+        ),
+        # UseCase(
+        #     title="Tool Calling for vision models",
+        #     description=textwrap.dedent(
+        #         """
+        #         While Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only,
+        #         they are not able to do tool calling when prompt contains image inputs (along with text).
+        #         The recommended way would be to separate out the image understanding from the tool calling in successive prompts.
+        #         Here is an example of how that could be done,
+        #         """,
+        #     ),
+        #     dialogs=[llama3_1_builtin_tool_call_with_image_dialog()],
+        #     notes=textwrap.dedent(
+        #         """
+        #         - Instead of a single prompt (image understanding + tool call), we split into two prompts to achieve the same result.
+ # """ + # ), + # ), + UseCase( + title="Prompt format for base models", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows + """ + ), + dialogs=[ + TextCompletionContent(content="The color of the sky is blue but sometimes it can also be"), + ], + notes="- Same as Llama3.1", + ), + UseCase( + title="Prompt format for base models with Image", + description=textwrap.dedent( + """ + For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image, + """ + ), + dialogs=[ + TextCompletionContent( + content=[ + RawMediaItem(data=img), + RawTextItem(text="If I had to write a haiku for this one"), + ] + ), + ], + notes="- Note the placement of the special tags <|begin_of_text|> and <|image|>", + ), + ] diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py new file mode 100644 index 000000000..14fd86853 --- /dev/null +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List + +from llama_models.datatypes import ( + BuiltinTool, + RawMessage, + StopReason, + ToolCall, + ToolPromptFormat, +) + +from ..prompt_format import ( + # llama3_1_e2e_tool_call_dialog, + TextCompletionContent, + UseCase, + llama3_1_builtin_tool_call_dialog, + llama3_1_custom_tool_call_dialog, +) + + +def wolfram_alpha_response(): + return textwrap.dedent( + """ + { + "queryresult": { + "success": true, + "inputstring": "100th decimal of pi", + "pods": [ + { + "title": "Input interpretation", + "subpods": [ + { + "title": "", + "plaintext": "100th digit | \u03c0" + } + ] + }, + { + "title": "Nearby digits", + "subpods": [ + { + "title": "", + "plaintext": "...86208998628034825342117067982148086513282306647093..." + } + ] + }, + { + "title": "Result", + "primary": true, + "subpods": [ + { + "title": "", + "plaintext": "7" + } + ] + } + ] + } + } + """ + ) + + +def usecases() -> List[UseCase | str]: + return [ + textwrap.dedent( + """ + # Llama 3.1 - Prompt Formats + ## Tokens + Here is a list of special tokens that are supported by Llama 3.1: + - `<|begin_of_text|>`: Specifies the start of the prompt + - `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models. + - `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch. + - `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and tool] + - `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. 
This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
+            - `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
+                - at the end of a direct interaction between the model and the user
+                - at the end of multiple interactions between the model and any available tools
+                This token signals to the executor that the model has finished generating a response.
+            - `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
+            """
+        ),
+        textwrap.dedent(
+            """
+            There are 4 different roles that are supported by Llama 3.1
+            - `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
+            - `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
+            - `tool`: A new role introduced in Llama 3.1. This role is used to mark messages with the output of a tool call when sent back to the model from the executor. (The actual token used by the model for this role is "ipython".)
+            - `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
+            """
+        ),
+        UseCase(
+            title="Llama 3.1 Base Model",
+            description="Text completion for Llama 3.1 base model uses this format.",
+            dialogs=[TextCompletionContent(content="Color of sky is blue but sometimes can also be")],
+            notes="Note start special tag",
+        ),
+        "## Llama 3.1 Instruct Model",
+        UseCase(
+            title="User and assistant conversation",
+            description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
+            dialogs=[
+                [
+                    RawMessage(role="system", content="You are a helpful assistant"),
+                    RawMessage(
+                        role="user",
+                        content="Answer who are you in the form of jeopardy?",
+                    ),
+                ]
+            ],
+            notes="",
+        ),
+        "## Tool Calling Formats",
+        textwrap.dedent(
+            """
+            The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
+            - Brave Search: Tool call to perform web searches.
+            - Wolfram Alpha: Tool call to perform complex mathematical calculations.
+            - Code Interpreter: Enables the model to output python code.
+            """
+        ),
+        UseCase(
+            title="Builtin Tool Calling",
+            description=textwrap.dedent(
+                """
+                Here is an example of a conversation using brave search
+                """
+            ),
+            dialogs=[llama3_1_builtin_tool_call_dialog()],
+            notes=textwrap.dedent(
+                """
+                - Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
+                - The message body of the assistant response starts with a special tag <|python_tag|>
+                - As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
+                - The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha`
+                """
+            ),
+        ),
+        UseCase(
+            title="Builtin Code Interpreter",
+            description="Here is an actual example of model responding with code",
+            dialogs=[
+                [
+                    RawMessage(role="system", content="Environment: ipython"),
+                    RawMessage(
+                        role="user",
+                        content="Write code to check if number is prime, use that to see if the number 7 is prime",
+                    ),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - Model starts with <|python_tag|> and continues writing python code that it needs to be executed
+                - No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
+                """
+            ),
+        ),
+        UseCase(
+            title="Built-in tools full interaction",
+            description="Here is a full interaction with the built-in tools including the tool response and the final assistant response.",
+            dialogs=[
+                [
+                    RawMessage(
+                        role="system",
+                        content="Environment: ipython\nTools: brave_search, wolfram_alpha\n",
+                    ),
+                    RawMessage(role="user", content="What is the 100th decimal of pi?"),
+                    RawMessage(
+                        role="assistant",
+                        content="",
+                        stop_reason=StopReason.end_of_message,
+                        tool_calls=[
+                            ToolCall(
+                                call_id="tool_call_id",
+                                tool_name=BuiltinTool.wolfram_alpha,
+                                arguments={"query": "100th decimal of pi"},
+                            )
+                        ],
+                    ),
+                    RawMessage(
+                        role="tool",
+                        content=wolfram_alpha_response(),
+                    ),
+                ],
+            ],
+            notes=textwrap.dedent(
+                """
+                - Note the `<|python_tag|>` in the assistant response.
+                - Role is `tool` for the wolfram alpha response that is passed back to the model.
+                - Final message from assistant has <|eot_id|> tag.
+                """
+            ),
+        ),
+        "## Zero shot tool calling",
+        UseCase(
+            title="JSON based tool calling",
+            description=textwrap.dedent(
+                """
+                Llama models can now output custom tool calls from a single message to allow easier tool calling.
+                The following prompts provide an example of how custom tools can be called from the output of the model.
+                It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
+                """
+            ),
+            dialogs=[llama3_1_custom_tool_call_dialog()],
+            notes=textwrap.dedent(
+                """
+                - JSON format for providing tools needs name, description and parameters
+                - Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
+                - Instructions for tools added as a user message
+                - Only single tool calls are supported as of now
+                """
+            ),
+        ),
+        # FIXME: This is not working yet as expected
+        # UseCase(
+        #     title="E2E tool call example",
+        #     description=textwrap.dedent(
+        #         """
+        #         Here is an example showing the whole multi-step turn by taking custom tool outputs and passing back to the model.
+        #         """
+        #     ),
+        #     dialogs=[
+        #         llama3_1_e2e_tool_call_dialog(
+        #             tool_prompt_format=ToolPromptFormat.function_tag
+        #         )
+        #     ],
+        #     notes="",
+        # ),
+        "## Example of a user defined tool calling",
+        UseCase(
+            title="`<function>` based tool calling",
+            description=textwrap.dedent(
+                """
+                Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
+                In this example, we define a custom tool calling format using the `<function>` tag.
+ """ + ), + dialogs=[llama3_1_custom_tool_call_dialog(ToolPromptFormat.function_tag)], + notes=textwrap.dedent( + """ + - In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` + - Instructions for tools added as a user message + """ + ), + ), + ] diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py new file mode 100644 index 000000000..f42620d57 --- /dev/null +++ b/llama_stack/models/llama/prompt_format.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import json +import textwrap +from pathlib import Path +from typing import List + +from llama_models.datatypes import ( + RawContent, + RawMediaItem, + RawMessage, + RawTextItem, + StopReason, + ToolCall, + ToolPromptFormat, +) +from pydantic import BaseModel, Field + +from .llama3.interface import LLama31Interface +from .llama3.template_data import ( + system_message_builtin_code_only, + system_message_builtin_tools_only, + system_message_custom_tools_only, +) + + +class TextCompletionContent(BaseModel): + content: RawContent = "" + + +class UseCase(BaseModel): + title: str = "" + description: str = "" + dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) + notes: str = "" + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json + + def md_format(self): + section = textwrap.dedent( + """ + ## {title} + + {description} + + {dialogs_text} + {notes} + + """ + ) + return section.lstrip() + + def dialogs_to_text(self, generator) -> str: + def _code_block(text): + return f"```\n{text}\n```" + + text = "" + for dialog in self.dialogs: + if isinstance(dialog, str): + text += dialog + text += "\n\n" + continue + + elif isinstance(dialog, TextCompletionContent): + input_tokens, output_tokens = generator.text_completion_raw( + dialog.content, + max_gen_len=64, + temperature=0.1, + top_p=0.95, + ) + else: + input_tokens, output_tokens = generator.chat_completion_raw( + dialog, + max_gen_len=512, + temperature=0.0, + top_p=0.95, + tool_prompt_format=self.tool_prompt_format, + ) + text += "##### Input Prompt Format\n" + + # FIXME: This is added to undo the hack in chat_formatter where + # vision tokens are replaced with 128256. 
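+            # NOTE (added for clarity): 128256 is the placeholder id the chat
+            # formatter substitutes for the vision token; mapping it back makes
+            # the decoded prompt show the real <|image|> tag in the docs.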
+            input_tokens = [generator.formatter.vision_token if t == 128256 else t for t in input_tokens]
+
+            text += _code_block(generator.tokenizer.decode(input_tokens))
+            # TODO: Figure out if "↵" needs to be added for newlines or end or some indication
+            text += "\n\n"
+            text += "##### Model Response Format\n"
+            text += _code_block(generator.tokenizer.decode(output_tokens))
+            text += "\n\n"
+
+        return text
+
+    def to_text(self, generator):
+        section = self.md_format()
+        dialogs_text = self.dialogs_to_text(generator)
+        notes = f"##### Notes\n{self.notes}" if self.notes else ""
+        section = section.format(
+            title=self.title,
+            description=self.description,
+            dialogs_text=dialogs_text,
+            notes=notes,
+        )
+        return section
+
+
+def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
+    interface = LLama31Interface(tool_prompt_format)
+
+    messages = interface.system_messages(**system_message_builtin_tools_only())
+    messages += interface.user_message(content="Search the web for the latest price of 1oz gold?")
+
+    return messages
+
+
+def llama3_1_builtin_code_interpreter_dialog(tool_prompt_format=ToolPromptFormat.json):
+    interface = LLama31Interface(tool_prompt_format)
+
+    messages = interface.system_messages(**system_message_builtin_code_only())
+    messages += interface.user_message(
+        content="Write code to check if number is prime. Use it to verify if number 7 is prime"
+    )
+
+    return messages
+
+
+def llama3_1_builtin_tool_call_with_image_dialog(
+    tool_prompt_format=ToolPromptFormat.json,
+):
+    this_dir = Path(__file__).parent
+    with open(this_dir / "llama3/dog.jpg", "rb") as f:
+        img = f.read()
+
+    interface = LLama31Interface(tool_prompt_format)
+
+    messages = interface.system_messages(**system_message_builtin_tools_only())
+    messages += interface.user_message(content=[RawMediaItem(data=img), RawTextItem(text="What is this dog breed?")])
+    messages += interface.assistant_response_messages(
+        "Based on the description of the dog in the image, it appears to be a small breed dog, possibly a terrier mix",
+        StopReason.end_of_turn,
+    )
+    messages += interface.user_message("Search the web for some food recommendations for the identified breed")
+    return messages
+
+
+def llama3_1_custom_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
+    interface = LLama31Interface(tool_prompt_format)
+
+    messages = interface.system_messages(**system_message_custom_tools_only())
+    messages += interface.user_message(content="Use tools to get latest trending songs")
+    return messages
+
+
+def llama3_1_e2e_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
+    tool_response = json.dumps(["great song1", "awesome song2", "cool song3"])
+    interface = LLama31Interface(tool_prompt_format)
+
+    messages = interface.system_messages(**system_message_custom_tools_only())
+    messages += interface.user_message(content="Use tools to get latest trending songs")
+    messages.append(
+        RawMessage(
+            role="assistant",
+            content="",
+            stop_reason=StopReason.end_of_message,
+            tool_calls=[
+                ToolCall(
+                    call_id="call_id",
+                    tool_name="trending_songs",
+                    arguments={"n": "10", "genre": "latest"},
+                )
+            ],
+        ),
+    )
+    messages.append(
+        RawMessage(
+            role="assistant",
+            content=tool_response,
+        )
+    )
+    return messages
+
+
+def llama3_2_user_assistant_conversation():
+    return UseCase(
+        title="User and assistant conversation",
+        description="Here is a regular multi-turn user assistant conversation and how it's formatted.",
+        dialogs=[
+            [
+                RawMessage(role="system", content="You are a helpful assistant"),
+
RawMessage(role="user", content="Who are you?"), + ] + ], + notes="This format is unchanged from Llama3.1", + ) diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py new file mode 100644 index 000000000..6f4a5a885 --- /dev/null +++ b/llama_stack/models/llama/sku_list.py @@ -0,0 +1,1000 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +from .datatypes import ( + CheckpointQuantizationFormat, + CoreModelId, + Model, + SamplingParams, + TopPSamplingStrategy, +) + +LLAMA2_VOCAB_SIZE = 32000 +LLAMA3_VOCAB_SIZE = 128256 + + +def resolve_model(descriptor: str) -> Optional[Model]: + for m in all_registered_models(): + if descriptor in (m.descriptor(), m.huggingface_repo): + return m + return None + + +def all_registered_models() -> List[Model]: + return ( + llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models() + ) + + +def recommended_sampling_params() -> SamplingParams: + return SamplingParams( + strategy=TopPSamplingStrategy( + temperature=1.0, + top_p=0.9, + ) + ) + + +def llama2_family() -> List[Model]: + return [ + *llama2_base_models(), + *llama2_instruct_models(), + ] + + +def llama3_family() -> List[Model]: + return [ + *llama3_base_models(), + *llama3_instruct_models(), + ] + + +def llama3_1_family() -> List[Model]: + return [ + *llama3_1_base_models(), + *llama3_1_instruct_models(), + ] + + +def llama3_2_family() -> List[Model]: + return [ + *llama3_2_base_models(), + *llama3_2_instruct_models(), + ] + + +def llama3_3_family() -> List[Model]: + return [ + *llama3_3_instruct_models(), + ] + + +def llama2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b, + description="Llama 2 7b model", + huggingface_repo="meta-llama/Llama-2-7b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b, + description="Llama 2 13b model", + huggingface_repo="meta-llama/Llama-2-13b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b, + description="Llama 2 70b model", + huggingface_repo="meta-llama/Llama-2-70b", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + 
"use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b, + description="Llama 3 8b model", + huggingface_repo="meta-llama/Llama-3-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b, + description="Llama 3 70b model", + huggingface_repo="meta-llama/Llama-3-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b, + description="Llama 3.1 8b model", + huggingface_repo="meta-llama/Llama-3.1-8B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b, + description="Llama 3.1 70b model", + huggingface_repo="meta-llama/Llama-3.1-70B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp8", + description="Llama 3.1 405b model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + description="Llama 3.1 405b model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b, + variant="bf16-mp16", + description="Llama 3.1 405b model (BF16 weights for mp16)", + huggingface_repo="meta-llama/Llama-3.1-405B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + 
pth_file_count=16, + ), + ] + + +def llama3_2_base_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b, + description="Llama 3.2 1b model", + huggingface_repo="meta-llama/Llama-3.2-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b, + description="Llama 3.2 3b model", + huggingface_repo="meta-llama/Llama-3.2-3B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_11b_vision, + description="Llama 3.2 11b vision model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 448, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision, + description="Llama 3.2 90b vision model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama2_7b_chat, + description="Llama 2 7b chat model", + huggingface_repo="meta-llama/Llama-2-7b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_13b_chat, + description="Llama 2 13b chat model", + huggingface_repo="meta-llama/Llama-2-13b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama2_70b_chat, + description="Llama 2 70b chat model", + huggingface_repo="meta-llama/Llama-2-70b-chat", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + 
"ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_8b_instruct, + description="Llama 3 8b instruct model", + huggingface_repo="meta-llama/Llama-3-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_70b_instruct, + description="Llama 3 70b instruct model", + huggingface_repo="meta-llama/Llama-3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=8, + ), + ] + + +def llama3_1_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_1_8b_instruct, + description="Llama 3.1 8b instruct model", + huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_1_70b_instruct, + description="Llama 3.1 70b instruct model", + huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp8", + description="Llama 3.1 405b instruct model (BF16 weights)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + description="Llama 3.1 405b instruct model (FP8 quantized)", + huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", + quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + Model( + core_model_id=CoreModelId.llama3_1_405b_instruct, + variant="bf16-mp16", + description="Llama 3.1 405b instruct model (BF16 weights for mp16)", + 
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 16384, + "n_layers": 126, + "n_heads": 128, + "n_kv_heads": 16, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.2, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=16, + ), + ] + + +def arch_args_1b() -> dict: + return { + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def arch_args_3b() -> dict: + return { + "dim": 3072, + "n_layers": 28, + "n_heads": 24, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.0, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + } + + +def llama3_2_quantized_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 1b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_1b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-qlora-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized LoRA", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + "lora_args": { + "rank": 16, + "scale": 2.0, + }, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + variant="int4-spinquant-eo8", + quantization_format=CheckpointQuantizationFormat.int4, + description="Llama 3.2 3b INT4 quantized SpinQuant", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + **arch_args_3b(), + "quantization_args": { + "group_size": 256, + }, + }, + pth_file_count=1, + ), + ] + + +def llama3_2_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_2_1b_instruct, + description="Llama 3.2 1b instruct model", + huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_1b(), + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_3b_instruct, + description="Llama 3.2 3b instruct model", + huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args=arch_args_3b(), + pth_file_count=1, + ), + *llama3_2_quantized_models(), + 
Model( + core_model_id=CoreModelId.llama3_2_11b_vision_instruct, + description="Llama 3.2 11b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama3_2_90b_vision_instruct, + description="Llama 3.2 90b vision instruct model", + huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 20, + }, + pth_file_count=8, + ), + ] + + +def llama3_3_instruct_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama3_3_70b_instruct, + description="Llama 3.3 70b instruct", + huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 8192, + "n_layers": 80, + "n_heads": 64, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=8, + ), + ] + + +@lru_cache +def safety_models() -> List[Model]: + return [ + Model( + core_model_id=CoreModelId.llama_guard_3_11b_vision, + description="Llama Guard v3 11b vision system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vision_num_cross_attention_layers": 8, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + variant="int4", + description="Llama Guard v3 1b 'int4' quantized system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", + quantization_format=CheckpointQuantizationFormat.int4, + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 12, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "rope_freq_base": 500000.0, + "norm_eps": 1e-05, + "hidden_dim": 6400, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_1b, + description="Llama Guard v3 1b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-1B", + recommended_sampling_params=recommended_sampling_params(), + arch_args={ + "dim": 2048, + "n_layers": 16, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA3_VOCAB_SIZE, + "ffn_dim_multiplier": 1.5, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": True, + }, + pth_file_count=1, + ), + Model( + 
core_model_id=CoreModelId.llama_guard_3_8b, + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B", + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_3_8b, + variant="int8", + description="Llama Guard v3 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-3-8B-INT8", + quantization_format=CheckpointQuantizationFormat.int8, + arch_args={ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + "vocab_size": LLAMA3_VOCAB_SIZE, + }, + pth_file_count=1, + ), + Model( + core_model_id=CoreModelId.llama_guard_2_8b, + description="Llama Guard v2 8b system safety model", + huggingface_repo="meta-llama/Llama-Guard-2-8B", + arch_args={ + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": LLAMA2_VOCAB_SIZE, + "ffn_dim_multiplier": 1.3, + "multiple_of": 256, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": False, + }, + pth_file_count=1, + ), + ] + + +@dataclass +class LlamaDownloadInfo: + folder: str + files: List[str] + pth_size: int + + +def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: + """Information needed to download model from llamameta.net""" + + pth_count = model.pth_file_count + if model.core_model_id == CoreModelId.llama3_1_405b: + if pth_count == 16: + folder = "Llama-3.1-405B-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B" + else: + folder = "Llama-3.1-405B-MP8" + elif model.core_model_id == CoreModelId.llama3_1_405b_instruct: + if pth_count == 16: + folder = "Llama-3.1-405B-Instruct-MP16" + elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + folder = "Llama-3.1-405B-Instruct" + else: + folder = "Llama-3.1-405B-Instruct-MP8" + elif model.core_model_id == CoreModelId.llama_guard_3_8b: + if model.quantization_format == CheckpointQuantizationFormat.int8: + folder = "Llama-Guard-3-8B-INT8-HF" + else: + folder = "Llama-Guard-3-8B" + elif model.core_model_id == CoreModelId.llama_guard_2_8b: + folder = "llama-guard-2" + else: + folder = model.huggingface_repo.split("/")[-1] + if "Llama-2" in folder: + folder = folder.lower() + + files = ["checklist.chk"] + if ( + model.core_model_id == CoreModelId.llama_guard_3_8b + and model.quantization_format == CheckpointQuantizationFormat.int8 + ): + files.extend( + [ + "generation_config.json", + "model-00001-of-00002.safetensors", + "model-00002-of-00002.safetensors", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors.index.json", + ] + ) + elif ( + model.core_model_id == CoreModelId.llama_guard_3_1b + and model.quantization_format == CheckpointQuantizationFormat.int4 + ): + files.extend( + [ + "llama_guard_3_1b_pruned_xnnpack.pte", + "example-prompt.txt", + "params.json", + "tokenizer.model", + ] + ) + else: + files.extend( + [ + "tokenizer.model", + "params.json", + ] + ) + if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: + files.extend([f"fp8_scales_{i}.pt" for i in range(pth_count)]) + files.extend([f"consolidated.{i:02d}.pth" for i in range(pth_count)]) + + return 
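As an aside (not part of the patch): a minimal sketch of how the download helpers above compose, assuming `resolve_model` accepts a Hugging Face repo string as it does elsewhere in this tree; the chosen repo is only an example.

# Illustrative usage of the helpers added above; not part of the diff.
from llama_stack.models.llama.sku_list import resolve_model

model = resolve_model("meta-llama/Llama-Guard-3-8B")
if model is not None:
    info = llama_meta_net_info(model)
    # folder on llamameta.net, expected file names, and per-shard size hint
    print(info.folder, info.files, info.pth_size)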
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index b92f9dc0a..384582423 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -7,7 +7,6 @@
 from typing import Any, List, Optional, Protocol
 from urllib.parse import urlparse
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 from llama_stack.apis.benchmarks import Benchmark
@@ -18,6 +17,7 @@
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.tools import Tool
 from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.schema_utils import json_schema_type
 class ModelsProtocolPrivate(Protocol):
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 8ba7885cd..fc597d0f7 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -17,7 +17,6 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 import httpx
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from pydantic import TypeAdapter
 from llama_stack.apis.agents import (
@@ -63,6 +62,7 @@
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
index 4e3951ad3..b802937b6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
+++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
@@ -8,7 +8,6 @@
 import tempfile
 from typing import AsyncIterator, List, Optional, Union
 import pytest
-from llama_models.llama3.api.datatypes import BuiltinTool
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -41,6 +40,7 @@
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
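The mechanical pattern in these hunks (and in the many config diffs below) is always the same: `json_schema_type` now comes from `llama_stack.schema_utils` instead of `llama_models.schema_utils`. A minimal sketch of the new import style, with `ExampleProviderConfig` as a hypothetical name:

# Sketch of the post-migration import pattern; the class is illustrative only.
from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class ExampleProviderConfig(BaseModel):
    url: str = Field(description="Endpoint URL for the example provider.")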
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index e60c3b1be..2d2ec5c8f 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -23,20 +23,13 @@
 from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.datatypes import (
-    GreedySamplingStrategy,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel
@@ -47,6 +40,13 @@
     ResponseFormatType,
 )
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.datatypes import (
+    GreedySamplingStrategy,
+    Model,
+    SamplingParams,
+    TopPSamplingStrategy,
+)
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 61f0ee3f4..c79f97def 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -8,14 +8,6 @@
 import asyncio
 import logging
 from typing import AsyncGenerator, List, Optional, Union
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_models.sku_list import resolve_model
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -41,6 +33,13 @@
     ToolConfig,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index ef133274c..64f94a69d 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -10,10 +10,10 @@
 from functools import partial
 from typing import Any, Generator
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
+from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
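For readers tracking the sampling-strategy types that just moved: a small sketch of constructing `SamplingParams` from the new location. The keyword names follow the datatypes as used elsewhere in this patch; treat the exact field names as an assumption rather than a guarantee.

# Sketch only; field names assumed from the migrated datatypes module.
from llama_stack.models.llama.datatypes import SamplingParams, TopPSamplingStrategy

params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
)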
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
index 9be35ae70..a2dc00916 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
@@ -14,14 +14,14 @@
 from typing import Any, Dict, List, Optional
 import torch
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
-from llama_models.sku_list import resolve_model
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 from llama_stack.apis.inference import QuantizationType
+from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
+from llama_stack.models.llama.sku_list import resolve_model
 from ..config import MetaReferenceQuantizedInferenceConfig
diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py
index de2bae265..51ef2d273 100644
--- a/llama_stack/providers/inline/inference/vllm/config.py
+++ b/llama_stack/providers/inline/inference/vllm/config.py
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, field_validator
 from llama_stack.providers.utils.inference import supported_inference_models
+from llama_stack.schema_utils import json_schema_type
 @json_schema_type
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e75a9aac3..5536ea3a5 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -11,7 +11,6 @@
 from typing import AsyncGenerator, List, Optional
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -35,6 +34,7 @@
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 735af8c79..98e16f9d7 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -13,8 +13,6 @@
 from typing import Any, Callable, Dict
 import torch
-from llama_models.datatypes import Model
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
 from torchtune.models.llama3 import llama3_tokenizer
@@ -24,6 +22,8 @@
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 from llama_stack.apis.post_training import DatasetFormat
+from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.sku_list import resolve_model
 class ModelConfig(BaseModel):
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index ba11736d6..c77d9305f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -6,8 +6,6 @@
 from datetime import datetime
 from typing import Any, Dict, Optional
-from llama_models.schema_utils import webmethod
-
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
@@ -27,6 +25,7 @@
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
+from llama_stack.schema_utils import webmethod
 class TorchtunePostTrainingImpl:
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index ef379aff2..4ab59fec4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -14,7 +14,6 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 import torch
-from llama_models.sku_list import resolve_model
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
@@ -46,6 +45,7 @@
 )
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.validator import (
     validate_input_dataset_schema,
 )
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 32d6d5100..af0987fa8 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -8,9 +8,6 @@
 import re
 from string import Template
 from typing import Any, Dict, List, Optional
-from llama_models.datatypes import CoreModelId
-from llama_models.llama3.api.datatypes import Role
-
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -26,6 +23,7 @@
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
+from llama_stack.models.llama.datatypes import CoreModelId, Role
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py
index ae859842d..9eae9ed67 100644
--- a/llama_stack/providers/inline/vector_io/faiss/config.py
+++ b/llama_stack/providers/inline/vector_io/faiss/config.py
@@ -6,13 +6,13 @@
 from typing import Any, Dict
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
+from llama_stack.schema_utils import json_schema_type
 @json_schema_type
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 917ac7a25..e896f0597 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -8,7 +8,6 @@
 import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -28,6 +27,7 @@
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 2158fc5b4..1ce267e8d 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -7,9 +7,7 @@
 from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import TopKSamplingStrategy
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,6 +26,7 @@
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId, TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py
index 6eb4dffec..81682c980 100644
--- a/llama_stack/providers/remote/inference/cerebras/config.py
+++ b/llama_stack/providers/remote/inference/cerebras/config.py
@@ -7,9 +7,10 @@
 import os
 from typing import Any, Dict, Optional
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr
+from llama_stack.schema_utils import json_schema_type
+
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py
index ae2b056ea..6aaf7e594 100644
--- a/llama_stack/providers/remote/inference/databricks/config.py
+++ b/llama_stack/providers/remote/inference/databricks/config.py
@@ -5,9 +5,10 @@
 # the root directory of this source tree.
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
+from llama_stack.schema_utils import json_schema_type
+
 @json_schema_type
 class DatabricksImplConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index d56be1465..3d306e61f 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -6,7 +6,6 @@
 from typing import AsyncGenerator, List, Optional
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
@@ -25,6 +24,7 @@
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py
index aa4c2d1de..005dfe829 100644
--- a/llama_stack/providers/remote/inference/fireworks/config.py
+++ b/llama_stack/providers/remote/inference/fireworks/config.py
@@ -6,9 +6,10 @@
 from typing import Any, Dict, Optional
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field, SecretStr
+from llama_stack.schema_utils import json_schema_type
+
 @json_schema_type
 class FireworksImplConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 7e8f85313..acf37b248 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,7 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
-from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -30,6 +29,7 @@
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_model_alias,
diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py
index 7c5023410..cb2619437 100644
--- a/llama_stack/providers/remote/inference/groq/config.py
+++ b/llama_stack/providers/remote/inference/groq/config.py
@@ -6,9 +6,10 @@
 from typing import Optional
-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
+from llama_stack.schema_utils import json_schema_type
+
 @json_schema_type
 class GroqConfig(BaseModel):
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 59ec8b0d2..441b6af5c 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -9,9 +9,6 @@
 from typing import AsyncIterator, List, Optional, Union
 import groq
 from groq import Groq
-from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
-from llama_models.sku_list import CoreModelId
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -29,6 +26,8 @@
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 2445c1b39..f1138e789 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -24,7 +24,6 @@
 from groq.types.chat.chat_completion_user_message_param import (
 )
 from groq.types.chat.completion_create_params import CompletionCreateParams
 from groq.types.shared.function_definition import FunctionDefinition
-from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -44,6 +43,7 @@
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_tool_call,
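Several of the hunks above move `ToolDefinition`/`ToolParamDefinition` to the new module. A minimal sketch of building a tool definition from the new import path; the tool name and parameter are hypothetical, and the keyword names are assumed from how these types are used in this patch:

# Illustrative tool definition using the relocated datatypes; not part of the diff.
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition

weather_tool = ToolDefinition(
    tool_name="get_weather",
    description="Fetch the current weather for a city.",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="Name of the city to look up.",
            required=True,
        ),
    },
)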
TopPSamplingStrategy, -) -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, -) from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, @@ -87,6 +76,15 @@ from llama_stack.apis.inference import ( ToolResponseMessage, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + GreedySamplingStrategy, + StopReason, + ToolCall, + ToolDefinition, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 1c12d0d91..f524c0734 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -8,7 +8,6 @@ import logging from typing import AsyncGenerator, List, Optional, Union import httpx -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py index 1a9582052..e59cfe59b 100644 --- a/llama_stack/providers/remote/inference/runpod/config.py +++ b/llama_stack/providers/remote/inference/runpod/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class RunpodImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index a3c615418..1abb17336 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -6,11 +6,11 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.models.llama.datatypes import Message # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py index 1798841df..a30c29b74 100644 --- a/llama_stack/providers/remote/inference/sambanova/config.py +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class SambaNovaImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py 
b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 3546ee977..b906e0dcb 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -7,12 +7,6 @@ import json from typing import AsyncGenerator -from llama_models.datatypes import ( - CoreModelId, - GreedySamplingStrategy, - TopKSamplingStrategy, - TopPSamplingStrategy, -) from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from openai import OpenAI @@ -23,6 +17,12 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.models.llama.datatypes import ( + CoreModelId, + GreedySamplingStrategy, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py index 4f690dec6..6ad663662 100644 --- a/llama_stack/providers/remote/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class TGIImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 72eaa6c31..1909e01f8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -11,7 +11,6 @@ from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( @@ -31,6 +30,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index a56cb5bb8..fda3b8f43 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,9 +6,10 @@ from typing import Any, Dict, Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field, SecretStr +from llama_stack.schema_utils import json_schema_type + @json_schema_type class TogetherImplConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 916e64ad4..054501da8 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -6,7 +6,6 @@ from typing import AsyncGenerator, List, Optional, Union -from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from 
llama_models.llama3.api.tokenizer import Tokenizer from together import Together @@ -29,6 +28,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index a3a4c6930..c75cc8926 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class VLLMInferenceAdapterConfig(BaseModel): diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8f9cf68a8..b22284302 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -7,10 +7,9 @@ import json import logging from typing import AsyncGenerator, List, Optional, Union -from llama_models.llama3.api import StopReason, ToolCall +from llama_models.datatypes import StopReason, ToolCall from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models from openai import OpenAI from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus @@ -37,6 +36,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/safety/bedrock/config.py b/llama_stack/providers/remote/safety/bedrock/config.py index 8c61decf3..1ca8d95cb 100644 --- a/llama_stack/providers/remote/safety/bedrock/config.py +++ b/llama_stack/providers/remote/safety/bedrock/config.py @@ -5,9 +5,8 @@ # the root directory of this source tree. 
-from llama_models.schema_utils import json_schema_type - from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +from llama_stack.schema_utils import json_schema_type @json_schema_type diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 564f76088..8ef9f5705 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -18,6 +17,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.datatypes import ToolsProtocolPrivate from .config import BraveSearchToolConfig diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 2a64d7c67..7811de1ca 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class PGVectorVectorIOConfig(BaseModel): diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 613cfa6e4..f212882d8 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -6,9 +6,10 @@ from typing import Optional -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel +from llama_stack.schema_utils import json_schema_type + @json_schema_type class QdrantVectorIOConfig(BaseModel): diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 45b276cc3..2e7bd537f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -7,8 +7,6 @@ import os import pytest -from llama_models.datatypes import SamplingParams, TopPSamplingStrategy -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, @@ -25,6 +23,7 @@ from llama_stack.apis.agents import ( ) from llama_stack.apis.inference import CompletionMessage, UserMessage from llama_stack.apis.safety import ViolationLevel +from llama_stack.models.llama.datatypes import BuiltinTool, SamplingParams, TopPSamplingStrategy from llama_stack.providers.datatypes import Api # How to run this test: diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index 3eba991c1..34725e957 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -23,8 +23,6 @@ from groq.types.chat.chat_completion_message_tool_call import ( Function, ) from groq.types.shared.function_definition import 
FunctionDefinition -from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( @@ -38,6 +36,7 @@ from llama_stack.apis.inference import ( ToolDefinition, UserMessage, ) +from llama_stack.models.llama.datatypes import GreedySamplingStrategy, ToolParamDefinition, TopPSamplingStrategy from llama_stack.providers.remote.inference.groq.groq_utils import ( convert_chat_completion_request, convert_chat_completion_response, diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index c087c5df2..323c6cb6a 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -6,19 +6,18 @@ import unittest -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolParamDefinition, - ToolPromptFormat, -) - from llama_stack.apis.inference import ( ChatCompletionRequest, SystemMessage, ToolConfig, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 14ed2fc4b..f25b95004 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -6,14 +6,6 @@ import pytest -from llama_models.llama3.api.datatypes import ( - SamplingParams, - StopReason, - ToolCall, - ToolDefinition, - ToolParamDefinition, - ToolPromptFormat, -) from pydantic import BaseModel, ValidationError from llama_stack.apis.common.content_types import ToolCallParseStatus @@ -30,6 +22,14 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.models import ListModelsResponse, Model +from llama_stack.models.llama.datatypes import ( + SamplingParams, + StopReason, + ToolCall, + ToolDefinition, + ToolParamDefinition, + ToolPromptFormat, +) from .utils import group_chunks diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py index 3901dc2e3..febd13045 100644 --- a/llama_stack/providers/tests/report.py +++ b/llama_stack/providers/tests/report.py @@ -9,11 +9,12 @@ from collections import defaultdict from pathlib import Path import pytest -from llama_models.datatypes import CoreModelId -from llama_models.sku_list import all_registered_models from pytest import ExitCode from pytest_html.basereport import _process_outcome +from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_list import all_registered_models + INFERENCE_APIS = ["chat_completion"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] SUPPORTED_MODELS = { diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py index 64fe30f55..cab3725da 100644 --- a/llama_stack/providers/utils/inference/__init__.py +++ b/llama_stack/providers/utils/inference/__init__.py @@ -6,8 +6,8 @@ from typing import List -from llama_models.datatypes import * # noqa: F403 -from llama_models.sku_list import all_registered_models +from 
llama_stack.models.llama.datatypes import * # noqa: F403 +from llama_stack.models.llama.sku_list import all_registered_models def is_supported_safety_model(model: Model) -> bool: diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 9345da949..c5f6cd6b5 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -7,9 +7,8 @@ from collections import namedtuple from typing import List, Optional -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 33f0f4e22..da8e3ce2d 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -7,14 +7,7 @@ import json import logging from typing import AsyncGenerator, Dict, List, Optional, Union -from llama_models.datatypes import ( - GreedySamplingStrategy, - SamplingParams, - TopKSamplingStrategy, - TopPSamplingStrategy, -) from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import StopReason, ToolCall from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel @@ -37,6 +30,14 @@ from llama_stack.apis.inference import ( Message, TokenLogProbs, ) +from llama_stack.models.llama.datatypes import ( + GreedySamplingStrategy, + SamplingParams, + StopReason, + ToolCall, + TopKSamplingStrategy, + TopPSamplingStrategy, +) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 15149e059..b7945dee7 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -13,25 +13,7 @@ import re from typing import List, Optional, Tuple, Union import httpx -from llama_models.datatypes import ModelFamily, is_multimodal from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import ( - RawContent, - RawContentItem, - RawMediaItem, - RawMessage, - RawTextItem, - Role, - ToolPromptFormat, -) -from llama_models.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) -from llama_models.sku_list import resolve_model from PIL import Image as PIL_Image from llama_stack.apis.common.content_types import ( @@ -51,6 +33,25 @@ from llama_stack.apis.inference import ( ToolChoice, UserMessage, ) +from llama_stack.models.llama.datatypes import ( + ModelFamily, + RawContent, + RawContentItem, + RawMediaItem, + RawMessage, + RawTextItem, + Role, + ToolPromptFormat, + is_multimodal, +) +from llama_stack.models.llama.llama3.prompt_templates import ( + BuiltinToolGenerator, + FunctionTagCustomToolGenerator, + JsonCustomToolGenerator, + PythonListCustomToolGenerator, + SystemDefaultGenerator, +) +from llama_stack.models.llama.sku_list import resolve_model from 
llama_stack.providers.utils.inference import supported_inference_models log = logging.getLogger(__name__) diff --git a/llama_stack/providers/utils/kvstore/sqlite/config.py b/llama_stack/providers/utils/kvstore/sqlite/config.py index a616c90d0..6a8b0a7cf 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/config.py +++ b/llama_stack/providers/utils/kvstore/sqlite/config.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field +from llama_stack.schema_utils import json_schema_type + @json_schema_type class SqliteControlPlaneConfig(BaseModel): diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 80c58a2c7..924274c42 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -9,9 +9,10 @@ import inspect from functools import wraps from typing import Any, AsyncGenerator, Callable, Type, TypeVar -from llama_models.llama3.api.datatypes import Primitive from pydantic import BaseModel +from llama_stack.models.llama.datatypes import Primitive + T = TypeVar("T") diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py new file mode 100644 index 000000000..56b9e5e4c --- /dev/null +++ b/llama_stack/schema_utils.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, TypeVar + +from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 + +T = TypeVar("T") + + +@dataclass +class WebMethod: + route: Optional[str] = None + public: bool = False + request_examples: Optional[List[Any]] = None + response_examples: Optional[List[Any]] = None + method: Optional[str] = None + + +def webmethod( + route: Optional[str] = None, + method: Optional[str] = None, + public: Optional[bool] = False, + request_examples: Optional[List[Any]] = None, + response_examples: Optional[List[Any]] = None, +) -> Callable[[T], T]: + """ + Decorator that supplies additional metadata to an endpoint operation function. + + :param route: The URL path pattern associated with this operation which path parameters are substituted into. + :param public: True if the operation can be invoked without prior authentication. + :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. + :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. + """ + + def wrap(cls: T) -> T: + cls.__webmethod__ = WebMethod( + route=route, + method=method, + public=public or False, + request_examples=request_examples, + response_examples=response_examples, + ) + return cls + + return wrap diff --git a/llama_stack/scripts/generate_prompt_format.py b/llama_stack/scripts/generate_prompt_format.py new file mode 100644 index 000000000..ecdde900f --- /dev/null +++ b/llama_stack/scripts/generate_prompt_format.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
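# [Editor's note -- illustrative sketch, not part of the patch. The new
# `llama_stack/schema_utils.py` added above replaces the old
# `llama_models.schema_utils` import path; a hypothetical endpoint shows the
# metadata round-trip the OpenAPI generator relies on (route name invented):]
#
#     from llama_stack.schema_utils import webmethod
#
#     @webmethod(route="/v1/health", method="GET", public=True)
#     def health() -> dict: ...
#
#     assert health.__webmethod__.route == "/v1/health"
#     assert health.__webmethod__.public is True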
+ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import importlib +from pathlib import Path +from typing import Optional + +import fire + +# from llama_stack.models.llama.datatypes import * # noqa: F403 +from llama_models.llama3.reference_impl.generation import Llama + +THIS_DIR = Path(__file__).parent.resolve() + + +def run_main( + ckpt_dir: str, + module_name: str, + output_path: str, + model_parallel_size: Optional[int] = None, +): + module = importlib.import_module(module_name) + assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function" + tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model") + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=512, + max_batch_size=1, + model_parallel_size=model_parallel_size, + ) + + use_cases = module.usecases() + text = "" + for u in use_cases: + if isinstance(u, str): + use_case_text = f"\n{u}\n" + else: + use_case_text = u.to_text(generator) + + text += use_case_text + print(use_case_text) + + text += "Thank You!\n" + + with open(output_path, "w") as f: + f.write(text) + + +def main(): + fire.Fire(run_main) + + +if __name__ == "__main__": + main() diff --git a/llama_stack/strong_typing/__init__.py b/llama_stack/strong_typing/__init__.py new file mode 100644 index 000000000..d832dcf6f --- /dev/null +++ b/llama_stack/strong_typing/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +Provides auxiliary services for working with Python type annotations, converting typed data to and from JSON, +and generating a JSON schema for a complex type. +""" + +__version__ = "0.3.4" +__author__ = "Levente Hunyadi" +__copyright__ = "Copyright 2021-2024, Levente Hunyadi" +__license__ = "MIT" +__maintainer__ = "Levente Hunyadi" +__status__ = "Production" diff --git a/llama_stack/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py new file mode 100644 index 000000000..fd183da18 --- /dev/null +++ b/llama_stack/strong_typing/auxiliary.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. 
+ +:see: https://github.com/hunyadi/strong_typing +""" + +import dataclasses +import sys +from dataclasses import is_dataclass +from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload + +if sys.version_info >= (3, 9): + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated + +if sys.version_info >= (3, 10): + from typing import TypeAlias as TypeAlias +else: + from typing_extensions import TypeAlias as TypeAlias + +if sys.version_info >= (3, 11): + from typing import dataclass_transform as dataclass_transform +else: + from typing_extensions import dataclass_transform as dataclass_transform + +T = TypeVar("T") + + +def _compact_dataclass_repr(obj: object) -> str: + """ + Compact data-class representation where positional arguments are used instead of keyword arguments. + + :param obj: A data-class object. + :returns: A string that matches the pattern `Class(arg1, arg2, ...)`. + """ + + if is_dataclass(obj): + arglist = ", ".join(repr(getattr(obj, field.name)) for field in dataclasses.fields(obj)) + return f"{obj.__class__.__name__}({arglist})" + else: + return obj.__class__.__name__ + + +class CompactDataClass: + "A data class whose repr() uses positional rather than keyword arguments." + + def __repr__(self) -> str: + return _compact_dataclass_repr(self) + + +@overload +def typeannotation(cls: Type[T], /) -> Type[T]: ... + + +@overload +def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ... + + +@dataclass_transform(eq_default=True, order_default=False) +def typeannotation( + cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False +) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + """ + Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. + + :param cls: The data-class type to transform into a type annotation. + :param eq: Whether to generate functions to support equality comparison. + :param order: Whether to generate functions to support ordering. + :returns: A data-class type, or a wrapper for data-class types. + """ + + def wrap(cls: Type[T]) -> Type[T]: + setattr(cls, "__repr__", _compact_dataclass_repr) + if not dataclasses.is_dataclass(cls): + cls = dataclasses.dataclass( # type: ignore[call-overload] + cls, + init=True, + repr=False, + eq=eq, + order=order, + unsafe_hash=False, + frozen=True, + ) + return cls + + # see if decorator is used as @typeannotation or @typeannotation() + if cls is None: + # called with parentheses + return wrap + else: + # called without parentheses + return wrap(cls) + + +@typeannotation +class Alias: + "Alternative name of a property, typically used in JSON serialization." + + name: str + + +@typeannotation +class Signed: + "Signedness of an integer type." + + is_signed: bool + + +@typeannotation +class Storage: + "Number of bytes the binary representation of an integer type takes, e.g. 4 bytes for an int32." + + bytes: int + + +@typeannotation +class IntegerRange: + "Minimum and maximum value of an integer. The range is inclusive." + + minimum: int + maximum: int + + +@typeannotation +class Precision: + "Precision of a floating-point value." + + significant_digits: int + decimal_digits: int = 0 + + @property + def integer_digits(self) -> int: + return self.significant_digits - self.decimal_digits + + +@typeannotation +class TimePrecision: + """ + Precision of a timestamp or time interval. 
+ + :param decimal_digits: Number of fractional digits retained in the sub-seconds field for a timestamp. + """ + + decimal_digits: int = 0 + + +@typeannotation +class Length: + "Exact length of a string." + + value: int + + +@typeannotation +class MinLength: + "Minimum length of a string." + + value: int + + +@typeannotation +class MaxLength: + "Maximum length of a string." + + value: int + + +@typeannotation +class SpecialConversion: + "Indicates that the annotated type is subject to custom conversion rules." + + +int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)] +int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)] +int32: TypeAlias = Annotated[ + int, + Signed(True), + Storage(4), + IntegerRange(-2147483648, 2147483647), +] +int64: TypeAlias = Annotated[ + int, + Signed(True), + Storage(8), + IntegerRange(-9223372036854775808, 9223372036854775807), +] + +uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)] +uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)] +uint32: TypeAlias = Annotated[ + int, + Signed(False), + Storage(4), + IntegerRange(0, 4294967295), +] +uint64: TypeAlias = Annotated[ + int, + Signed(False), + Storage(8), + IntegerRange(0, 18446744073709551615), +] + +float32: TypeAlias = Annotated[float, Storage(4)] +float64: TypeAlias = Annotated[float, Storage(8)] + +# maps globals of type Annotated[T, ...] defined in this module to their string names +_auxiliary_types: Dict[object, str] = {} +module = sys.modules[__name__] +for var in dir(module): + typ = getattr(module, var) + if getattr(typ, "__metadata__", None) is not None: + # type is Annotated[T, ...] + _auxiliary_types[typ] = var + + +def get_auxiliary_format(data_type: object) -> Optional[str]: + "Returns the JSON format string corresponding to an auxiliary type." + + return _auxiliary_types.get(data_type) diff --git a/llama_stack/strong_typing/classdef.py b/llama_stack/strong_typing/classdef.py new file mode 100644 index 000000000..d2d8688e4 --- /dev/null +++ b/llama_stack/strong_typing/classdef.py @@ -0,0 +1,440 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
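# [Editor's note -- illustrative sketch, not part of the patch. `classdef.py`
# below reconstructs Python types from JSON Schema documents via
# `schema_to_type`; a hedged usage example, with the hosting module name
# chosen here purely for illustration:]
#
#     import types
#     from llama_stack.strong_typing.classdef import schema_to_type
#
#     schema = {
#         "$schema": "http://json-schema.org/draft-07/schema#",
#         "type": "object",
#         "properties": {"name": {"type": "string", "title": "Display name"}},
#         "required": ["name"],
#         "additionalProperties": False,
#     }
#     host = types.ModuleType("generated")
#     Person = schema_to_type(schema, module=host, class_name="Person")
#     # Person is now a dataclass with a required `name: str` field.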
+ +import copy +import dataclasses +import datetime +import decimal +import enum +import ipaddress +import math +import re +import sys +import types +import typing +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union + +from .auxiliary import ( + Alias, + Annotated, + MaxLength, + Precision, + float32, + float64, + int16, + int32, + int64, +) +from .core import JsonType, Schema +from .docstring import Docstring, DocstringParam +from .inspection import TypeLike +from .serialization import json_to_object, object_to_json + +T = TypeVar("T") + + +@dataclass +class JsonSchemaNode: + title: Optional[str] + description: Optional[str] + + +@dataclass +class JsonSchemaType(JsonSchemaNode): + type: str + format: Optional[str] + + +@dataclass +class JsonSchemaBoolean(JsonSchemaType): + type: Literal["boolean"] + const: Optional[bool] + default: Optional[bool] + examples: Optional[List[bool]] + + +@dataclass +class JsonSchemaInteger(JsonSchemaType): + type: Literal["integer"] + const: Optional[int] + default: Optional[int] + examples: Optional[List[int]] + enum: Optional[List[int]] + minimum: Optional[int] + maximum: Optional[int] + + +@dataclass +class JsonSchemaNumber(JsonSchemaType): + type: Literal["number"] + const: Optional[float] + default: Optional[float] + examples: Optional[List[float]] + minimum: Optional[float] + maximum: Optional[float] + exclusiveMinimum: Optional[float] + exclusiveMaximum: Optional[float] + multipleOf: Optional[float] + + +@dataclass +class JsonSchemaString(JsonSchemaType): + type: Literal["string"] + const: Optional[str] + default: Optional[str] + examples: Optional[List[str]] + enum: Optional[List[str]] + minLength: Optional[int] + maxLength: Optional[int] + + +@dataclass +class JsonSchemaArray(JsonSchemaType): + type: Literal["array"] + items: "JsonSchemaAny" + + +@dataclass +class JsonSchemaObject(JsonSchemaType): + type: Literal["object"] + properties: Optional[Dict[str, "JsonSchemaAny"]] + additionalProperties: Optional[bool] + required: Optional[List[str]] + + +@dataclass +class JsonSchemaRef(JsonSchemaNode): + ref: Annotated[str, Alias("$ref")] + + +@dataclass +class JsonSchemaAllOf(JsonSchemaNode): + allOf: List["JsonSchemaAny"] + + +@dataclass +class JsonSchemaAnyOf(JsonSchemaNode): + anyOf: List["JsonSchemaAny"] + + +@dataclass +class Discriminator: + propertyName: str + mapping: Dict[str, str] + + +@dataclass +class JsonSchemaOneOf(JsonSchemaNode): + oneOf: List["JsonSchemaAny"] + discriminator: Optional[Discriminator] + + +JsonSchemaAny = Union[ + JsonSchemaRef, + JsonSchemaBoolean, + JsonSchemaInteger, + JsonSchemaNumber, + JsonSchemaString, + JsonSchemaArray, + JsonSchemaObject, + JsonSchemaOneOf, +] + + +@dataclass +class JsonSchemaTopLevelObject(JsonSchemaObject): + schema: Annotated[str, Alias("$schema")] + definitions: Optional[Dict[str, JsonSchemaAny]] + + +def integer_range_to_type(min_value: float, max_value: float) -> type: + if min_value >= -(2**15) and max_value < 2**15: + return int16 + elif min_value >= -(2**31) and max_value < 2**31: + return int32 + else: + return int64 + + +def enum_safe_name(name: str) -> str: + name = re.sub(r"\W", "_", name) + is_dunder = name.startswith("__") + is_sunder = name.startswith("_") and name.endswith("_") + if is_dunder or is_sunder: # provide an alternative for dunder and sunder names + name = f"v{name}" + return name + + +def enum_values_to_type( + module: types.ModuleType, + name: str, + values: Dict[str, Any], + title: 
Optional[str] = None, + description: Optional[str] = None, +) -> Type[enum.Enum]: + enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore + + # assign the newly created type to the same module where the defining class is + enum_class.__module__ = module.__name__ + enum_class.__doc__ = str(Docstring(short_description=title, long_description=description)) + setattr(module, name, enum_class) + + return enum.unique(enum_class) + + +def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str) -> TypeLike: + """ + Creates a Python type from a JSON schema. + + :param schema: The JSON schema that the types would correspond to. + :param module: The module in which to create the new types. + :param class_name: The name assigned to the top-level class. + """ + + top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) + if top_node.definitions is not None: + for type_name, type_node in top_node.definitions.items(): + type_def = node_to_typedef(module, type_name, type_node) + if type_def.default is not dataclasses.MISSING: + raise TypeError("disallowed: `default` for top-level type definitions") + + setattr(type_def.type, "__module__", module.__name__) + setattr(module, type_name, type_def.type) + + return node_to_typedef(module, class_name, top_node).type + + +@dataclass +class TypeDef: + type: TypeLike + default: Any = dataclasses.MISSING + + +def json_to_value(target_type: TypeLike, data: JsonType) -> Any: + if data is not None: + return json_to_object(target_type, data) + else: + return dataclasses.MISSING + + +def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode) -> TypeDef: + if isinstance(node, JsonSchemaRef): + match_obj = re.match(r"^#/definitions/(\w+)$", node.ref) + if not match_obj: + raise ValueError(f"invalid reference: {node.ref}") + + type_name = match_obj.group(1) + return TypeDef(getattr(module, type_name), dataclasses.MISSING) + + elif isinstance(node, JsonSchemaBoolean): + if node.const is not None: + return TypeDef(Literal[node.const], dataclasses.MISSING) + + default = json_to_value(bool, node.default) + return TypeDef(bool, default) + + elif isinstance(node, JsonSchemaInteger): + if node.const is not None: + return TypeDef(Literal[node.const], dataclasses.MISSING) + + integer_type: TypeLike + if node.format == "int16": + integer_type = int16 + elif node.format == "int32": + integer_type = int32 + elif node.format == "int64": + integer_type = int64 + else: + if node.enum is not None: + integer_type = integer_range_to_type(min(node.enum), max(node.enum)) + elif node.minimum is not None and node.maximum is not None: + integer_type = integer_range_to_type(node.minimum, node.maximum) + else: + integer_type = int + + default = json_to_value(integer_type, node.default) + return TypeDef(integer_type, default) + + elif isinstance(node, JsonSchemaNumber): + if node.const is not None: + return TypeDef(Literal[node.const], dataclasses.MISSING) + + number_type: TypeLike + if node.format == "float32": + number_type = float32 + elif node.format == "float64": + number_type = float64 + else: + if ( + node.exclusiveMinimum is not None + and node.exclusiveMaximum is not None + and node.exclusiveMinimum == -node.exclusiveMaximum + ): + integer_digits = round(math.log10(node.exclusiveMaximum)) + else: + integer_digits = None + + if node.multipleOf is not None: + decimal_digits = -round(math.log10(node.multipleOf)) + else: + decimal_digits = None + + if integer_digits is not None and 
decimal_digits is not None: + number_type = Annotated[ + decimal.Decimal, + Precision(integer_digits + decimal_digits, decimal_digits), + ] + else: + number_type = float + + default = json_to_value(number_type, node.default) + return TypeDef(number_type, default) + + elif isinstance(node, JsonSchemaString): + if node.const is not None: + return TypeDef(Literal[node.const], dataclasses.MISSING) + + string_type: TypeLike + if node.format == "date-time": + string_type = datetime.datetime + elif node.format == "uuid": + string_type = uuid.UUID + elif node.format == "ipv4": + string_type = ipaddress.IPv4Address + elif node.format == "ipv6": + string_type = ipaddress.IPv6Address + + elif node.enum is not None: + string_type = enum_values_to_type( + module, + context, + {enum_safe_name(e): e for e in node.enum}, + title=node.title, + description=node.description, + ) + + elif node.maxLength is not None: + string_type = Annotated[str, MaxLength(node.maxLength)] + else: + string_type = str + + default = json_to_value(string_type, node.default) + return TypeDef(string_type, default) + + elif isinstance(node, JsonSchemaArray): + type_def = node_to_typedef(module, context, node.items) + if type_def.default is not dataclasses.MISSING: + raise TypeError("disallowed: `default` for array element type") + list_type = List[(type_def.type,)] # type: ignore + return TypeDef(list_type, dataclasses.MISSING) + + elif isinstance(node, JsonSchemaObject): + if node.properties is None: + return TypeDef(JsonType, dataclasses.MISSING) + + if node.additionalProperties is None or node.additionalProperties is not False: + raise TypeError("expected: `additionalProperties` equals `false`") + + required = node.required if node.required is not None else [] + + class_name = context + + fields: List[Tuple[str, Any, dataclasses.Field]] = [] + params: Dict[str, DocstringParam] = {} + for prop_name, prop_node in node.properties.items(): + type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node) + if prop_name in required: + prop_type = type_def.type + else: + prop_type = Union[(None, type_def.type)] + fields.append((prop_name, prop_type, dataclasses.field(default=type_def.default))) + prop_desc = prop_node.title or prop_node.description + if prop_desc is not None: + params[prop_name] = DocstringParam(prop_name, prop_desc) + + fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING) + if sys.version_info >= (3, 12): + class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__) + else: + class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__}) + class_type.__doc__ = str( + Docstring( + short_description=node.title, + long_description=node.description, + params=params, + ) + ) + setattr(module, class_name, class_type) + return TypeDef(class_type, dataclasses.MISSING) + + elif isinstance(node, JsonSchemaOneOf): + union_defs = tuple(node_to_typedef(module, context, n) for n in node.oneOf) + if any(d.default is not dataclasses.MISSING for d in union_defs): + raise TypeError("disallowed: `default` for union member type") + union_types = tuple(d.type for d in union_defs) + return TypeDef(Union[union_types], dataclasses.MISSING) + + raise NotImplementedError() + + +@dataclass +class SchemaFlatteningOptions: + qualified_names: bool = False + recursive: bool = False + + +def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema: + top_node = typing.cast(JsonSchemaTopLevelObject, 
json_to_object(JsonSchemaTopLevelObject, schema)) + flattener = SchemaFlattener(options) + obj = flattener.flatten(top_node) + return typing.cast(Schema, object_to_json(obj)) + + +class SchemaFlattener: + options: SchemaFlatteningOptions + + def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None: + self.options = options or SchemaFlatteningOptions() + + def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject: + if source_node.type != "object": + return source_node + + source_props = source_node.properties or {} + target_props: Dict[str, JsonSchemaAny] = {} + + source_reqs = source_node.required or [] + target_reqs: List[str] = [] + + for name, prop in source_props.items(): + if not isinstance(prop, JsonSchemaObject): + target_props[name] = prop + if name in source_reqs: + target_reqs.append(name) + continue + + if self.options.recursive: + obj = self.flatten(prop) + else: + obj = prop + if obj.properties is not None: + if self.options.qualified_names: + target_props.update((f"{name}.{n}", p) for n, p in obj.properties.items()) + else: + target_props.update(obj.properties.items()) + if obj.required is not None: + if self.options.qualified_names: + target_reqs.extend(f"{name}.{n}" for n in obj.required) + else: + target_reqs.extend(obj.required) + + target_node = copy.copy(source_node) + target_node.properties = target_props or None + target_node.additionalProperties = False + target_node.required = target_reqs or None + return target_node diff --git a/llama_stack/strong_typing/core.py b/llama_stack/strong_typing/core.py new file mode 100644 index 000000000..501b6a5db --- /dev/null +++ b/llama_stack/strong_typing/core.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +from typing import Dict, List, Union + + +class JsonObject: + "Placeholder type for an unrestricted JSON object." + + +class JsonArray: + "Placeholder type for an unrestricted JSON array." + + +# a JSON type with possible `null` values +JsonType = Union[ + None, + bool, + int, + float, + str, + Dict[str, "JsonType"], + List["JsonType"], +] + +# a JSON type that cannot contain `null` values +StrictJsonType = Union[ + bool, + int, + float, + str, + Dict[str, "StrictJsonType"], + List["StrictJsonType"], +] + +# a meta-type that captures the object type in a JSON schema +Schema = Dict[str, JsonType] diff --git a/llama_stack/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py new file mode 100644 index 000000000..4c4ee9d89 --- /dev/null +++ b/llama_stack/strong_typing/deserializer.py @@ -0,0 +1,876 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. 
+
+:see: https://github.com/hunyadi/strong_typing
+"""
+
+import abc
+import base64
+import dataclasses
+import datetime
+import enum
+import inspect
+import ipaddress
+import sys
+import typing
+import uuid
+from types import ModuleType
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from .core import JsonType
+from .exception import JsonKeyError, JsonTypeError, JsonValueError
+from .inspection import (
+    TypeLike,
+    create_object,
+    enum_value_types,
+    evaluate_type,
+    get_class_properties,
+    get_class_property,
+    get_resolved_hints,
+    is_dataclass_instance,
+    is_dataclass_type,
+    is_named_tuple_type,
+    is_type_annotated,
+    is_type_literal,
+    is_type_optional,
+    unwrap_annotated_type,
+    unwrap_literal_values,
+    unwrap_optional_type,
+)
+from .mapping import python_field_to_json_property
+from .name import python_type_to_str
+
+E = TypeVar("E", bound=enum.Enum)
+T = TypeVar("T")
+R = TypeVar("R")
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class Deserializer(abc.ABC, Generic[T]):
+    "Parses a JSON value into a Python type."
+
+    def build(self, context: Optional[ModuleType]) -> None:
+        """
+        Creates auxiliary parsers that this parser is depending on.
+
+        :param context: A module context for evaluating types specified as a string.
+        """
+
+    @abc.abstractmethod
+    def parse(self, data: JsonType) -> T:
+        """
+        Parses a JSON value into a Python type.
+
+        :param data: The JSON value to de-serialize.
+        :returns: The Python object that the JSON value de-serializes to.
+        """
+
+
+class NoneDeserializer(Deserializer[None]):
+    "Parses JSON `null` values into Python `None`."
+
+    def parse(self, data: JsonType) -> None:
+        if data is not None:
+            raise JsonTypeError(f"`None` type expects JSON `null` but instead received: {data}")
+        return None
+
+
+class BoolDeserializer(Deserializer[bool]):
+    "Parses JSON `boolean` values into Python `bool` type."
+
+    def parse(self, data: JsonType) -> bool:
+        if not isinstance(data, bool):
+            raise JsonTypeError(f"`bool` type expects JSON `boolean` data but instead received: {data}")
+        return bool(data)
+
+
+class IntDeserializer(Deserializer[int]):
+    "Parses JSON `number` values into Python `int` type."
+
+    def parse(self, data: JsonType) -> int:
+        if not isinstance(data, int):
+            raise JsonTypeError(f"`int` type expects integer data as JSON `number` but instead received: {data}")
+        return int(data)
+
+
+class FloatDeserializer(Deserializer[float]):
+    "Parses JSON `number` values into Python `float` type."
+
+    def parse(self, data: JsonType) -> float:
+        if not isinstance(data, float) and not isinstance(data, int):
+            raise JsonTypeError(f"`float` type expects data as JSON `number` but instead received: {data}")
+        return float(data)
+
+
+class StringDeserializer(Deserializer[str]):
+    "Parses JSON `string` values into Python `str` type."
+
+    def parse(self, data: JsonType) -> str:
+        if not isinstance(data, str):
+            raise JsonTypeError(f"`str` type expects JSON `string` data but instead received: {data}")
+        return str(data)
+
+
+class BytesDeserializer(Deserializer[bytes]):
+    "Parses JSON `string` values of Base64-encoded strings into Python `bytes` type."
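    # [Editor's note -- illustrative, not part of the patch. The primitive
    # deserializers above are normally obtained through `create_deserializer`
    # (defined later in this file), e.g.:
    #     create_deserializer(bytes).parse("aGVsbG8=")   # -> b"hello"
    #     create_deserializer(int).parse(42)             # -> 42
    # A JSON value of the wrong type raises JsonTypeError.]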
+ + def parse(self, data: JsonType) -> bytes: + if not isinstance(data, str): + raise JsonTypeError(f"`bytes` type expects JSON `string` data but instead received: {data}") + return base64.b64decode(data, validate=True) + + +class DateTimeDeserializer(Deserializer[datetime.datetime]): + "Parses JSON `string` values representing timestamps in ISO 8601 format to Python `datetime` with time zone." + + def parse(self, data: JsonType) -> datetime.datetime: + if not isinstance(data, str): + raise JsonTypeError(f"`datetime` type expects JSON `string` data but instead received: {data}") + + if data.endswith("Z"): + data = f"{data[:-1]}+00:00" # Python's isoformat() does not support military time zones like "Zulu" for UTC + timestamp = datetime.datetime.fromisoformat(data) + if timestamp.tzinfo is None: + raise JsonValueError(f"timestamp lacks explicit time zone designator: {data}") + return timestamp + + +class DateDeserializer(Deserializer[datetime.date]): + "Parses JSON `string` values representing dates in ISO 8601 format to Python `date` type." + + def parse(self, data: JsonType) -> datetime.date: + if not isinstance(data, str): + raise JsonTypeError(f"`date` type expects JSON `string` data but instead received: {data}") + + return datetime.date.fromisoformat(data) + + +class TimeDeserializer(Deserializer[datetime.time]): + "Parses JSON `string` values representing time instances in ISO 8601 format to Python `time` type with time zone." + + def parse(self, data: JsonType) -> datetime.time: + if not isinstance(data, str): + raise JsonTypeError(f"`time` type expects JSON `string` data but instead received: {data}") + + return datetime.time.fromisoformat(data) + + +class UUIDDeserializer(Deserializer[uuid.UUID]): + "Parses JSON `string` values of UUID strings into Python `uuid.UUID` type." + + def parse(self, data: JsonType) -> uuid.UUID: + if not isinstance(data, str): + raise JsonTypeError(f"`UUID` type expects JSON `string` data but instead received: {data}") + return uuid.UUID(data) + + +class IPv4Deserializer(Deserializer[ipaddress.IPv4Address]): + "Parses JSON `string` values of IPv4 address strings into Python `ipaddress.IPv4Address` type." + + def parse(self, data: JsonType) -> ipaddress.IPv4Address: + if not isinstance(data, str): + raise JsonTypeError(f"`IPv4Address` type expects JSON `string` data but instead received: {data}") + return ipaddress.IPv4Address(data) + + +class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]): + "Parses JSON `string` values of IPv6 address strings into Python `ipaddress.IPv6Address` type." + + def parse(self, data: JsonType) -> ipaddress.IPv6Address: + if not isinstance(data, str): + raise JsonTypeError(f"`IPv6Address` type expects JSON `string` data but instead received: {data}") + return ipaddress.IPv6Address(data) + + +class ListDeserializer(Deserializer[List[T]]): + "Recursively de-serializes a JSON array into a Python `list`." 
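    # [Editor's note -- illustrative, not part of the patch. Container parsers
    # compose recursively through `_get_deserializer`, e.g.:
    #     create_deserializer(List[int]).parse([1, 2, 3])   # -> [1, 2, 3]
    # A non-array input raises JsonTypeError naming the full `List[T]` type.]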
+
+    item_type: Type[T]
+    item_parser: Deserializer
+
+    def __init__(self, item_type: Type[T]) -> None:
+        self.item_type = item_type
+
+    def build(self, context: Optional[ModuleType]) -> None:
+        self.item_parser = _get_deserializer(self.item_type, context)
+
+    def parse(self, data: JsonType) -> List[T]:
+        if not isinstance(data, list):
+            type_name = python_type_to_str(self.item_type)
+            raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}")
+
+        return [self.item_parser.parse(item) for item in data]
+
+
+class DictDeserializer(Deserializer[Dict[K, V]]):
+    "Recursively de-serializes a JSON object into a Python `dict`."
+
+    key_type: Type[K]
+    value_type: Type[V]
+    value_parser: Deserializer[V]
+
+    def __init__(self, key_type: Type[K], value_type: Type[V]) -> None:
+        self.key_type = key_type
+        self.value_type = value_type
+        self._check_key_type()
+
+    def build(self, context: Optional[ModuleType]) -> None:
+        self.value_parser = _get_deserializer(self.value_type, context)
+
+    def _check_key_type(self) -> None:
+        if self.key_type is str:
+            return
+
+        if issubclass(self.key_type, enum.Enum):
+            value_types = enum_value_types(self.key_type)
+            if len(value_types) != 1:
+                raise JsonTypeError(
+                    f"type `{self.container_type}` has invalid key type, "
+                    f"enumerations must have a consistent member value type but several types found: {value_types}"
+                )
+            value_type = value_types.pop()
+            if value_type is not str:
+                raise JsonTypeError(
+                    f"type `{self.container_type}` has invalid enumeration key type, expected `enum.Enum` with string values"
+                )
+            return
+
+        raise JsonTypeError(
+            f"type `{self.container_type}` has invalid key type, expected `str` or `enum.Enum` with string values"
+        )
+
+    @property
+    def container_type(self) -> str:
+        key_type_name = python_type_to_str(self.key_type)
+        value_type_name = python_type_to_str(self.value_type)
+        return f"Dict[{key_type_name}, {value_type_name}]"
+
+    def parse(self, data: JsonType) -> Dict[K, V]:
+        if not isinstance(data, dict):
+            raise JsonTypeError(
+                f"type `{self.container_type}` expects JSON `object` data but instead received: {data}"
+            )
+
+        return dict(
+            (self.key_type(key), self.value_parser.parse(value))  # type: ignore[call-arg]
+            for key, value in data.items()
+        )
+
+
+class SetDeserializer(Deserializer[Set[T]]):
+    "Recursively de-serializes a JSON list into a Python `set`."
+
+    member_type: Type[T]
+    member_parser: Deserializer
+
+    def __init__(self, member_type: Type[T]) -> None:
+        self.member_type = member_type
+
+    def build(self, context: Optional[ModuleType]) -> None:
+        self.member_parser = _get_deserializer(self.member_type, context)
+
+    def parse(self, data: JsonType) -> Set[T]:
+        if not isinstance(data, list):
+            type_name = python_type_to_str(self.member_type)
+            raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}")
+
+        return set(self.member_parser.parse(item) for item in data)
+
+
+class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
+    "Recursively de-serializes a JSON list into a Python `tuple`."
+
+    item_types: Tuple[Type[Any], ...]
+    item_parsers: Tuple[Deserializer[Any], ...]
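    # [Editor's note -- illustrative, not part of the patch. Fixed-shape tuples
    # validate arity as well as element types, e.g.:
    #     create_deserializer(Tuple[str, int]).parse(["a", 1])   # -> ("a", 1)
    # A length mismatch raises JsonValueError; a non-array raises JsonTypeError.]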
+ + def __init__(self, item_types: Tuple[Type[Any], ...]) -> None: + self.item_types = item_types + + def build(self, context: Optional[ModuleType]) -> None: + self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types) + + @property + def container_type(self) -> str: + type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types) + return f"Tuple[{type_names}]" + + def parse(self, data: JsonType) -> Tuple[Any, ...]: + if not isinstance(data, list) or len(data) != len(self.item_parsers): + if not isinstance(data, list): + raise JsonTypeError( + f"type `{self.container_type}` expects JSON `array` data but instead received: {data}" + ) + else: + count = len(self.item_parsers) + raise JsonValueError( + f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}" + ) + + return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data)) + + +class UnionDeserializer(Deserializer): + "De-serializes a JSON value (of any type) into a Python union type." + + member_types: Tuple[type, ...] + member_parsers: Tuple[Deserializer, ...] + + def __init__(self, member_types: Tuple[type, ...]) -> None: + self.member_types = member_types + + def build(self, context: Optional[ModuleType]) -> None: + self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types) + + def parse(self, data: JsonType) -> Any: + for member_parser in self.member_parsers: + # iterate over potential types of discriminated union + try: + return member_parser.parse(data) + except (JsonKeyError, JsonTypeError): + # indicates a required field is missing from JSON dict -OR- the data cannot be cast to the expected type, + # i.e. we don't have the type that we are looking for + continue + + type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types) + raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}") + + +def get_literal_properties(typ: type) -> Set[str]: + "Returns the names of all properties in a class that are of a literal type." + + return set( + property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type) + ) + + +def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]: + "Returns a set of properties with literal type that are common across all specified classes." + + if not types or not all(isinstance(typ, type) for typ in types): + return set() + + props = get_literal_properties(types[0]) + for typ in types[1:]: + props = props & get_literal_properties(typ) + + return props + + +class TaggedUnionDeserializer(Deserializer): + "De-serializes a JSON value with one or more disambiguating properties into a Python union type." + + member_types: Tuple[type, ...] 
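    # [Editor's note -- illustrative, not part of the patch. A union counts as
    # "tagged" when every member declares a shared Literal-typed property, e.g.:
    #
    #     from dataclasses import dataclass
    #     from typing import Literal, Union
    #
    #     @dataclass
    #     class Cat:
    #         kind: Literal["cat"]
    #
    #     @dataclass
    #     class Dog:
    #         kind: Literal["dog"]
    #
    #     create_deserializer(Union[Cat, Dog]).parse({"kind": "dog"})
    #     # -> Dog(kind="dog")
    #
    # Unions without such a property fall back to the trial-and-error
    # UnionDeserializer above.]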
+ disambiguating_properties: Set[str] + member_parsers: Dict[Tuple[str, Any], Deserializer] + + def __init__(self, member_types: Tuple[type, ...]) -> None: + self.member_types = member_types + self.disambiguating_properties = get_discriminating_properties(member_types) + + def build(self, context: Optional[ModuleType]) -> None: + self.member_parsers = {} + for member_type in self.member_types: + for property_name in self.disambiguating_properties: + literal_type = get_class_property(member_type, property_name) + if not literal_type: + continue + + for literal_value in unwrap_literal_values(literal_type): + tpl = (property_name, literal_value) + if tpl in self.member_parsers: + raise JsonTypeError( + f"disambiguating property `{property_name}` in type `{self.union_type}` has a duplicate value: {literal_value}" + ) + + self.member_parsers[tpl] = _get_deserializer(member_type, context) + + @property + def union_type(self) -> str: + type_names = ", ".join(python_type_to_str(member_type) for member_type in self.member_types) + return f"Union[{type_names}]" + + def parse(self, data: JsonType) -> Any: + if not isinstance(data, dict): + raise JsonTypeError( + f"tagged union type `{self.union_type}` expects JSON `object` data but instead received: {data}" + ) + + for property_name in self.disambiguating_properties: + disambiguating_value = data.get(property_name) + if disambiguating_value is None: + continue + + member_parser = self.member_parsers.get((property_name, disambiguating_value)) + if member_parser is None: + raise JsonTypeError( + f"disambiguating property value is invalid for tagged union type `{self.union_type}`: {data}" + ) + + return member_parser.parse(data) + + raise JsonTypeError( + f"disambiguating property value is missing for tagged union type `{self.union_type}`: {data}" + ) + + +class LiteralDeserializer(Deserializer): + "De-serializes a JSON value into a Python literal type." + + values: Tuple[Any, ...] + parser: Deserializer + + def __init__(self, values: Tuple[Any, ...]) -> None: + self.values = values + + def build(self, context: Optional[ModuleType]) -> None: + literal_type_tuple = tuple(type(value) for value in self.values) + literal_type_set = set(literal_type_tuple) + if len(literal_type_set) != 1: + value_names = ", ".join(repr(value) for value in self.values) + raise TypeError( + f"type `Literal[{value_names}]` expects consistent literal value types but got: {literal_type_tuple}" + ) + + literal_type = literal_type_set.pop() + self.parser = _get_deserializer(literal_type, context) + + def parse(self, data: JsonType) -> Any: + value = self.parser.parse(data) + if value not in self.values: + value_names = ", ".join(repr(value) for value in self.values) + raise JsonTypeError(f"type `Literal[{value_names}]` could not be instantiated from: {data}") + return value + + +class EnumDeserializer(Deserializer[E]): + "Returns an enumeration instance based on the enumeration value read from a JSON value." + + enum_type: Type[E] + + def __init__(self, enum_type: Type[E]) -> None: + self.enum_type = enum_type + + def parse(self, data: JsonType) -> E: + return self.enum_type(data) + + +class CustomDeserializer(Deserializer[T]): + "Uses the `from_json` class method in class to de-serialize the object from JSON." 
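    # [Editor's note -- illustrative, not part of the patch. Any class exposing
    # a callable `from_json` attribute bypasses reflection entirely, e.g.:
    #
    #     class Color:
    #         @classmethod
    #         def from_json(cls, data):
    #             return cls(...)  # custom parsing of `data`
    #
    # `_create_deserializer` (below) detects the attribute and wraps it here.]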
+ + converter: Callable[[JsonType], T] + + def __init__(self, converter: Callable[[JsonType], T]) -> None: + self.converter = converter + + def parse(self, data: JsonType) -> T: + return self.converter(data) + + +class FieldDeserializer(abc.ABC, Generic[T, R]): + """ + Deserializes a JSON property into a Python object field. + + :param property_name: The name of the JSON property to read from a JSON `object`. + :param field_name: The name of the field in a Python class to write data to. + :param parser: A compatible deserializer that can handle the field's type. + """ + + property_name: str + field_name: str + parser: Deserializer[T] + + def __init__(self, property_name: str, field_name: str, parser: Deserializer[T]) -> None: + self.property_name = property_name + self.field_name = field_name + self.parser = parser + + @abc.abstractmethod + def parse_field(self, data: Dict[str, JsonType]) -> R: ... + + +class RequiredFieldDeserializer(FieldDeserializer[T, T]): + "Deserializes a JSON property into a mandatory Python object field." + + def parse_field(self, data: Dict[str, JsonType]) -> T: + if self.property_name not in data: + raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}") + + return self.parser.parse(data[self.property_name]) + + +class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]): + "Deserializes a JSON property into an optional Python object field with a default value of `None`." + + def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]: + value = data.get(self.property_name) + if value is not None: + return self.parser.parse(value) + else: + return None + + +class DefaultFieldDeserializer(FieldDeserializer[T, T]): + "Deserializes a JSON property into a Python object field with an explicit default value." + + default_value: T + + def __init__( + self, + property_name: str, + field_name: str, + parser: Deserializer, + default_value: T, + ) -> None: + super().__init__(property_name, field_name, parser) + self.default_value = default_value + + def parse_field(self, data: Dict[str, JsonType]) -> T: + value = data.get(self.property_name) + if value is not None: + return self.parser.parse(value) + else: + return self.default_value + + +class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]): + "Deserializes a JSON property into an optional Python object field with an explicit default value factory." + + default_factory: Callable[[], T] + + def __init__( + self, + property_name: str, + field_name: str, + parser: Deserializer[T], + default_factory: Callable[[], T], + ) -> None: + super().__init__(property_name, field_name, parser) + self.default_factory = default_factory + + def parse_field(self, data: Dict[str, JsonType]) -> T: + value = data.get(self.property_name) + if value is not None: + return self.parser.parse(value) + else: + return self.default_factory() + + +class ClassDeserializer(Deserializer[T]): + "Base class for de-serializing class-like types such as data classes, named tuples and regular classes." 
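    # [Editor's note: parsing is strict in both directions -- a missing required
    # property raises JsonKeyError, and a JSON key that maps to no declared
    # field is rejected as unrecognized rather than silently dropped.]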
+
+    class_type: type
+    property_parsers: List[FieldDeserializer]
+    property_fields: Set[str]
+
+    def __init__(self, class_type: Type[T]) -> None:
+        self.class_type = class_type
+
+    def assign(self, property_parsers: List[FieldDeserializer]) -> None:
+        self.property_parsers = property_parsers
+        self.property_fields = set(property_parser.property_name for property_parser in property_parsers)
+
+    def parse(self, data: JsonType) -> T:
+        if not isinstance(data, dict):
+            type_name = python_type_to_str(self.class_type)
+            raise JsonTypeError(f"type `{type_name}` expects JSON `object` data but instead received: {data}")
+
+        object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data)
+
+        field_values = {}
+        for property_parser in self.property_parsers:
+            field_values[property_parser.field_name] = property_parser.parse_field(object_data)
+
+        if not self.property_fields.issuperset(object_data):
+            unassigned_names = [name for name in object_data if name not in self.property_fields]
+            raise JsonKeyError(f"unrecognized fields in JSON object: {unassigned_names}")
+
+        return self.create(**field_values)
+
+    def create(self, **field_values: Any) -> T:
+        "Instantiates an object with a collection of property values."
+
+        obj: T = create_object(self.class_type)
+
+        # use `setattr` on newly created object instance
+        for field_name, field_value in field_values.items():
+            setattr(obj, field_name, field_value)
+        return obj
+
+
+class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
+    "De-serializes a named tuple from a JSON `object`."
+
+    def build(self, context: Optional[ModuleType]) -> None:
+        property_parsers: List[FieldDeserializer] = [
+            RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context))
+            for field_name, field_type in get_resolved_hints(self.class_type).items()
+        ]
+        super().assign(property_parsers)
+
+    def create(self, **field_values: Any) -> NamedTuple:
+        return self.class_type(**field_values)
+
+
+class DataclassDeserializer(ClassDeserializer[T]):
+    "De-serializes a data class from a JSON `object`."
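    # [Editor's note -- illustrative, not part of the patch. Defaults declared
    # on the dataclass survive de-serialization, e.g.:
    #
    #     @dataclass
    #     class Point:
    #         x: int
    #         y: int = 0
    #
    #     create_deserializer(Point).parse({"x": 1})   # -> Point(x=1, y=0)
    # ]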
+ + def __init__(self, class_type: Type[T]) -> None: + if not dataclasses.is_dataclass(class_type): + raise TypeError("expected: data-class type") + super().__init__(class_type) # type: ignore[arg-type] + + def build(self, context: Optional[ModuleType]) -> None: + property_parsers: List[FieldDeserializer] = [] + resolved_hints = get_resolved_hints(self.class_type) + for field in dataclasses.fields(self.class_type): + field_type = resolved_hints[field.name] + property_name = python_field_to_json_property(field.name, field_type) + + is_optional = is_type_optional(field_type) + has_default = field.default is not dataclasses.MISSING + has_default_factory = field.default_factory is not dataclasses.MISSING + + if is_optional: + required_type: Type[T] = unwrap_optional_type(field_type) + else: + required_type = field_type + + parser = _get_deserializer(required_type, context) + + if has_default: + field_parser: FieldDeserializer = DefaultFieldDeserializer( + property_name, field.name, parser, field.default + ) + elif has_default_factory: + default_factory = typing.cast(Callable[[], Any], field.default_factory) + field_parser = DefaultFactoryFieldDeserializer(property_name, field.name, parser, default_factory) + elif is_optional: + field_parser = OptionalFieldDeserializer(property_name, field.name, parser) + else: + field_parser = RequiredFieldDeserializer(property_name, field.name, parser) + + property_parsers.append(field_parser) + + super().assign(property_parsers) + + +class FrozenDataclassDeserializer(DataclassDeserializer[T]): + "De-serializes a frozen data class from a JSON `object`." + + def create(self, **field_values: Any) -> T: + "Instantiates an object with a collection of property values." + + # create object instance without calling `__init__` + obj: T = create_object(self.class_type) + + # can't use `setattr` on frozen dataclasses, pass member variable values to `__init__` + obj.__init__(**field_values) # type: ignore + return obj + + +class TypedClassDeserializer(ClassDeserializer[T]): + "De-serializes a class with type annotations from a JSON `object` by iterating over class properties." + + def build(self, context: Optional[ModuleType]) -> None: + property_parsers: List[FieldDeserializer] = [] + for field_name, field_type in get_resolved_hints(self.class_type).items(): + property_name = python_field_to_json_property(field_name, field_type) + + is_optional = is_type_optional(field_type) + + if is_optional: + required_type: Type[T] = unwrap_optional_type(field_type) + else: + required_type = field_type + + parser = _get_deserializer(required_type, context) + + if is_optional: + field_parser: FieldDeserializer = OptionalFieldDeserializer(property_name, field_name, parser) + else: + field_parser = RequiredFieldDeserializer(property_name, field_name, parser) + + property_parsers.append(field_parser) + + super().assign(property_parsers) + + +def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer: + """ + Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string. + + When de-serializing a JSON object into a Python object, the following transformations are applied: + + * Fundamental types are parsed as `bool`, `int`, `float` or `str`. + * Date and time types are parsed from the ISO 8601 format with time zone into the corresponding Python type + `datetime`, `date` or `time`. + * Byte arrays are read from a string with Base64 encoding into a `bytes` instance. 
+ * UUIDs are extracted from a UUID string compliant with RFC 4122 into a `uuid.UUID` instance. + * Enumerations are instantiated with a lookup on enumeration value. + * Containers (e.g. `list`, `dict`, `set`, `tuple`) are parsed recursively. + * Complex objects with properties (including data class types) are populated from dictionaries of key-value pairs + using reflection (enumerating type annotations). + + :raises TypeError: A de-serializer engine cannot be constructed for the input type. + """ + + if context is None: + if isinstance(typ, type): + context = sys.modules[typ.__module__] + + return _get_deserializer(typ, context) + + +_CACHE: Dict[Tuple[str, str], Deserializer] = {} + + +def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer: + "Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string." + + cache_key = None + + if isinstance(typ, (str, typing.ForwardRef)): + if context is None: + raise TypeError(f"missing context for evaluating type: {typ}") + + if isinstance(typ, str): + if hasattr(context, typ): + cache_key = (context.__name__, typ) + elif isinstance(typ, typing.ForwardRef): + if hasattr(context, typ.__forward_arg__): + cache_key = (context.__name__, typ.__forward_arg__) + + typ = evaluate_type(typ, context) + + typ = unwrap_annotated_type(typ) if is_type_annotated(typ) else typ + + if isinstance(typ, type) and typing.get_origin(typ) is None: + cache_key = (typ.__module__, typ.__name__) + + if cache_key is not None: + deserializer = _CACHE.get(cache_key) + if deserializer is None: + deserializer = _create_deserializer(typ) + + # store de-serializer immediately in cache to avoid stack overflow for recursive types + _CACHE[cache_key] = deserializer + + if isinstance(typ, type): + # use type's own module as context for evaluating member types + context = sys.modules[typ.__module__] + + # create any de-serializers this de-serializer is depending on + deserializer.build(context) + else: + # special forms are not always hashable, create a new de-serializer every time + deserializer = _create_deserializer(typ) + deserializer.build(context) + + return deserializer + + +def _create_deserializer(typ: TypeLike) -> Deserializer: + "Creates a de-serializer engine to parse an object obtained from a JSON string." + + # check for well-known types + if typ is type(None): + return NoneDeserializer() + elif typ is bool: + return BoolDeserializer() + elif typ is int: + return IntDeserializer() + elif typ is float: + return FloatDeserializer() + elif typ is str: + return StringDeserializer() + elif typ is bytes: + return BytesDeserializer() + elif typ is datetime.datetime: + return DateTimeDeserializer() + elif typ is datetime.date: + return DateDeserializer() + elif typ is datetime.time: + return TimeDeserializer() + elif typ is uuid.UUID: + return UUIDDeserializer() + elif typ is ipaddress.IPv4Address: + return IPv4Deserializer() + elif typ is ipaddress.IPv6Address: + return IPv6Deserializer() + + # dynamically-typed collection types + if typ is list: + raise TypeError("explicit item type required: use `List[T]` instead of `list`") + if typ is dict: + raise TypeError("explicit key and value types required: use `Dict[K, V]` instead of `dict`") + if typ is set: + raise TypeError("explicit member type required: use `Set[T]` instead of `set`") + if typ is tuple: + raise TypeError("explicit item type list required: use `Tuple[T, ...]` instead of `tuple`") + + # generic types (e.g. list, dict, set, etc.) 
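    # [Editor's note: `typing.get_origin` erases parameters (List[int] -> list,
    # Dict[str, int] -> dict), so the branches below dispatch on the bare
    # container type while `typing.get_args` recovers the element types passed
    # to the recursive `_get_deserializer` calls.]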
+ origin_type = typing.get_origin(typ) + if origin_type is list: + (list_item_type,) = typing.get_args(typ) # unpack single tuple element + return ListDeserializer(list_item_type) + elif origin_type is dict: + key_type, value_type = typing.get_args(typ) + return DictDeserializer(key_type, value_type) + elif origin_type is set: + (set_member_type,) = typing.get_args(typ) # unpack single tuple element + return SetDeserializer(set_member_type) + elif origin_type is tuple: + return TupleDeserializer(typing.get_args(typ)) + elif origin_type is Union: + union_args = typing.get_args(typ) + if get_discriminating_properties(union_args): + return TaggedUnionDeserializer(union_args) + else: + return UnionDeserializer(union_args) + elif origin_type is Literal: + return LiteralDeserializer(typing.get_args(typ)) + + if not inspect.isclass(typ): + if is_dataclass_instance(typ): + raise TypeError(f"dataclass type expected but got instance: {typ}") + else: + raise TypeError(f"unable to de-serialize unrecognized type: {typ}") + + if issubclass(typ, enum.Enum): + return EnumDeserializer(typ) + + if is_named_tuple_type(typ): + return NamedTupleDeserializer(typ) + + # check if object has custom serialization method + convert_func = getattr(typ, "from_json", None) + if callable(convert_func): + return CustomDeserializer(convert_func) + + if is_dataclass_type(typ): + dataclass_params = getattr(typ, "__dataclass_params__", None) + if dataclass_params is not None and dataclass_params.frozen: + return FrozenDataclassDeserializer(typ) + else: + return DataclassDeserializer(typ) + + return TypedClassDeserializer(typ) diff --git a/llama_stack/strong_typing/docstring.py b/llama_stack/strong_typing/docstring.py new file mode 100644 index 000000000..9169aadfe --- /dev/null +++ b/llama_stack/strong_typing/docstring.py @@ -0,0 +1,399 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import builtins +import dataclasses +import inspect +import re +import sys +import types +import typing +from dataclasses import dataclass +from io import StringIO +from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +from .inspection import ( + DataclassInstance, + get_class_properties, + get_signature, + is_dataclass_type, + is_type_enum, +) + +T = TypeVar("T") + + +@dataclass +class DocstringParam: + """ + A parameter declaration in a parameter block. + + :param name: The name of the parameter. + :param description: The description text for the parameter. + """ + + name: str + description: str + param_type: type = inspect.Signature.empty + + def __str__(self) -> str: + return f":param {self.name}: {self.description}" + + +@dataclass +class DocstringReturns: + """ + A `returns` declaration extracted from a docstring. + + :param description: The description text for the return value. + """ + + description: str + return_type: type = inspect.Signature.empty + + def __str__(self) -> str: + return f":returns: {self.description}" + + +@dataclass +class DocstringRaises: + """ + A `raises` declaration extracted from a docstring. + + :param typename: The type name of the exception raised. 
+ :param description: The description associated with the exception raised. + """ + + typename: str + description: str + raise_type: type = inspect.Signature.empty + + def __str__(self) -> str: + return f":raises {self.typename}: {self.description}" + + +@dataclass +class Docstring: + """ + Represents the documentation string (a.k.a. docstring) for a type such as a (data) class or function. + + A docstring is broken down into the following components: + * A short description, which is the first block of text in the documentation string, and ends with a double + newline or a parameter block. + * A long description, which is the optional block of text following the short description, and ends with + a parameter block. + * A parameter block of named parameter and description string pairs in ReST-style. + * A `returns` declaration, which adds explanation to the return value. + * A `raises` declaration, which adds explanation to the exception type raised by the function on error. + + When the docstring is attached to a data class, it is understood as the documentation string of the class + `__init__` method. + + :param short_description: The short description text parsed from a docstring. + :param long_description: The long description text parsed from a docstring. + :param params: The parameter block extracted from a docstring. + :param returns: The returns declaration extracted from a docstring. + """ + + short_description: Optional[str] = None + long_description: Optional[str] = None + params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict) + returns: Optional[DocstringReturns] = None + raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict) + + @property + def full_description(self) -> Optional[str]: + if self.short_description and self.long_description: + return f"{self.short_description}\n\n{self.long_description}" + elif self.short_description: + return self.short_description + else: + return None + + def __str__(self) -> str: + output = StringIO() + + has_description = self.short_description or self.long_description + has_blocks = self.params or self.returns or self.raises + + if has_description: + if self.short_description and self.long_description: + output.write(self.short_description) + output.write("\n\n") + output.write(self.long_description) + elif self.short_description: + output.write(self.short_description) + + if has_blocks: + if has_description: + output.write("\n") + + for param in self.params.values(): + output.write("\n") + output.write(str(param)) + if self.returns: + output.write("\n") + output.write(str(self.returns)) + for raises in self.raises.values(): + output.write("\n") + output.write(str(raises)) + + s = output.getvalue() + output.close() + return s + + +def is_exception(member: object) -> TypeGuard[Type[BaseException]]: + return isinstance(member, type) and issubclass(member, BaseException) + + +def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]: + "Returns all exception classes declared in a module." + + return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)} + + +class SupportsDoc(Protocol): + __doc__: Optional[str] + + +def parse_type(typ: SupportsDoc) -> Docstring: + """ + Parse the docstring of a type into its components. + + :param typ: The type whose documentation string to parse. + :returns: Components of the documentation string. 
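+
+    A minimal illustration (hypothetical `Point` class; assumes `dataclass` is imported
+    from `dataclasses`):
+
+    >>> @dataclass
+    ... class Point:
+    ...     '''
+    ...     A point in the plane.
+
+    ...     :param x: Horizontal coordinate.
+    ...     '''
+    ...     x: float
+    >>> parse_type(Point).params["x"].description
+    'Horizontal coordinate.'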
+ """ + + doc = get_docstring(typ) + if doc is None: + return Docstring() + + docstring = parse_text(doc) + check_docstring(typ, docstring) + + # assign parameter and return types + if is_dataclass_type(typ): + properties = dict(get_class_properties(typing.cast(type, typ))) + + for name, param in docstring.params.items(): + param.param_type = properties[name] + + elif inspect.isfunction(typ): + signature = get_signature(typ) + for name, param in docstring.params.items(): + param.param_type = signature.parameters[name].annotation + if docstring.returns: + docstring.returns.return_type = signature.return_annotation + + # assign exception types + defining_module = inspect.getmodule(typ) + if defining_module: + context: Dict[str, type] = {} + context.update(get_exceptions(builtins)) + context.update(get_exceptions(defining_module)) + for exc_name, exc in docstring.raises.items(): + raise_type = context.get(exc_name) + if raise_type is None: + type_name = getattr(typ, "__qualname__", None) or getattr(typ, "__name__", None) or None + raise TypeError( + f"doc-string exception type `{exc_name}` is not an exception defined in the context of `{type_name}`" + ) + + exc.raise_type = raise_type + + return docstring + + +def parse_text(text: str) -> Docstring: + """ + Parse a ReST-style docstring into its components. + + :param text: The documentation string to parse, typically acquired as `type.__doc__`. + :returns: Components of the documentation string. + """ + + if not text: + return Docstring() + + # find block that starts object metadata block (e.g. `:param p:` or `:returns:`) + text = inspect.cleandoc(text) + match = re.search("^:", text, flags=re.MULTILINE) + if match: + desc_chunk = text[: match.start()] + meta_chunk = text[match.start() :] # noqa: E203 + else: + desc_chunk = text + meta_chunk = "" + + # split description text into short and long description + parts = desc_chunk.split("\n\n", 1) + + # ensure short description has no newlines + short_description = parts[0].strip().replace("\n", " ") or None + + # ensure long description preserves its structure (e.g. preformatted text) + if len(parts) > 1: + long_description = parts[1].strip() or None + else: + long_description = None + + params: Dict[str, DocstringParam] = {} + raises: Dict[str, DocstringRaises] = {} + returns = None + for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE): + chunk = match.group(0) + if not chunk: + continue + + args_chunk, desc_chunk = chunk.lstrip(":").split(":", 1) + args = args_chunk.split() + desc = re.sub(r"\s+", " ", desc_chunk.strip()) + + if len(args) > 0: + kw = args[0] + if len(args) == 2: + if kw == "param": + params[args[1]] = DocstringParam( + name=args[1], + description=desc, + ) + elif kw == "raise" or kw == "raises": + raises[args[1]] = DocstringRaises( + typename=args[1], + description=desc, + ) + + elif len(args) == 1: + if kw == "return" or kw == "returns": + returns = DocstringReturns(description=desc) + + return Docstring( + long_description=long_description, + short_description=short_description, + params=params, + returns=returns, + raises=raises, + ) + + +def has_default_docstring(typ: SupportsDoc) -> bool: + "Check if class has the auto-generated string assigned by @dataclass." 
+ + if not isinstance(typ, type): + return False + + if is_dataclass_type(typ): + return typ.__doc__ is not None and re.match(f"^{re.escape(typ.__name__)}[(].*[)]$", typ.__doc__) is not None + + if is_type_enum(typ): + return typ.__doc__ is not None and typ.__doc__ == "An enumeration." + + return False + + +def has_docstring(typ: SupportsDoc) -> bool: + "Check if class has a documentation string other than the auto-generated string assigned by @dataclass." + + if has_default_docstring(typ): + return False + + return bool(typ.__doc__) + + +def get_docstring(typ: SupportsDoc) -> Optional[str]: + if typ.__doc__ is None: + return None + + if has_default_docstring(typ): + return None + + return typ.__doc__ + + +def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False) -> None: + """ + Verifies the doc-string of a type. + + :raises TypeError: Raised on a mismatch between doc-string parameters, and function or type signature. + """ + + if is_dataclass_type(typ): + check_dataclass_docstring(typ, docstring, strict) + elif inspect.isfunction(typ): + check_function_docstring(typ, docstring, strict) + + +def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None: + """ + Verifies the doc-string of a data-class type. + + :param strict: Whether to check if all data-class members have doc-strings. + :raises TypeError: Raised on a mismatch between doc-string parameters and data-class members. + """ + + if not is_dataclass_type(typ): + raise TypeError("not a data-class type") + + properties = dict(get_class_properties(typ)) + class_name = typ.__name__ + + for name in docstring.params: + if name not in properties: + raise TypeError(f"doc-string parameter `{name}` is not a member of the data-class `{class_name}`") + + if not strict: + return + + for name in properties: + if name not in docstring.params: + raise TypeError(f"member `{name}` in data-class `{class_name}` is missing its doc-string") + + +def check_function_docstring(fn: Callable[..., Any], docstring: Docstring, strict: bool = False) -> None: + """ + Verifies the doc-string of a function or member function. + + :param strict: Whether to check if all function parameters and the return type have doc-strings. + :raises TypeError: Raised on a mismatch between doc-string parameters and function signature. 
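+
+    For example, a doc-string with `:param q:` for a function whose signature declares
+    only `p` fails with `TypeError`; in strict mode, an undocumented parameter `p`
+    fails likewise.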
+ """ + + signature = get_signature(fn) + func_name = fn.__qualname__ + + for name in docstring.params: + if name not in signature.parameters: + raise TypeError(f"doc-string parameter `{name}` is absent from signature of function `{func_name}`") + + if docstring.returns is not None and signature.return_annotation is inspect.Signature.empty: + raise TypeError(f"doc-string has returns description in function `{func_name}` with no return type annotation") + + if not strict: + return + + for name, param in signature.parameters.items(): + # ignore `self` in member function signatures + if name == "self" and ( + param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + continue + + if name not in docstring.params: + raise TypeError(f"function parameter `{name}` in `{func_name}` is missing its doc-string") + + if signature.return_annotation is not inspect.Signature.empty and docstring.returns is None: + raise TypeError(f"function `{func_name}` has no returns description in its doc-string") diff --git a/llama_stack/strong_typing/exception.py b/llama_stack/strong_typing/exception.py new file mode 100644 index 000000000..af037cc3c --- /dev/null +++ b/llama_stack/strong_typing/exception.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + + +class JsonKeyError(Exception): + "Raised when deserialization for a class or union type has failed because a matching member was not found." + + +class JsonValueError(Exception): + "Raised when (de)serialization of data has failed due to invalid value." + + +class JsonTypeError(Exception): + "Raised when deserialization of data has failed due to a type mismatch." diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py new file mode 100644 index 000000000..69bc15597 --- /dev/null +++ b/llama_stack/strong_typing/inspection.py @@ -0,0 +1,1034 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import dataclasses +import datetime +import enum +import importlib +import importlib.machinery +import importlib.util +import inspect +import re +import sys +import types +import typing +import uuid +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Protocol, + Set, + Tuple, + Type, + TypeVar, + Union, + runtime_checkable, +) + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +S = TypeVar("S") +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") + + +def _is_type_like(data_type: object) -> bool: + """ + Checks if the object is a type or type-like object (e.g. generic type). + + :param data_type: The object to validate. + :returns: True if the object is a type or type-like object. 
+ """ + + if isinstance(data_type, type): + # a standard type + return True + elif typing.get_origin(data_type) is not None: + # a generic type such as `list`, `dict` or `set` + return True + elif hasattr(data_type, "__forward_arg__"): + # an instance of `ForwardRef` + return True + elif data_type is Any: + # the special form `Any` + return True + else: + return False + + +if sys.version_info >= (3, 9): + TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any] + + def is_type_like( + data_type: object, + ) -> TypeGuard[TypeLike]: + """ + Checks if the object is a type or type-like object (e.g. generic type). + + :param data_type: The object to validate. + :returns: True if the object is a type or type-like object. + """ + + return _is_type_like(data_type) + +else: + TypeLike = object + + def is_type_like( + data_type: object, + ) -> bool: + return _is_type_like(data_type) + + +def evaluate_member_type(typ: Any, cls: type) -> Any: + """ + Evaluates a forward reference type in a dataclass member. + + :param typ: The dataclass member type to convert. + :param cls: The dataclass in which the member is defined. + :returns: The evaluated type. + """ + + return evaluate_type(typ, sys.modules[cls.__module__]) + + +def evaluate_type(typ: Any, module: types.ModuleType) -> Any: + """ + Evaluates a forward reference type. + + :param typ: The type to convert, typically a dataclass member type. + :param module: The context for the type, i.e. the module in which the member is defined. + :returns: The evaluated type. + """ + + if isinstance(typ, str): + # evaluate data-class field whose type annotation is a string + return eval(typ, module.__dict__, locals()) + if isinstance(typ, typing.ForwardRef): + if sys.version_info >= (3, 9): + return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset()) + else: + return typ._evaluate(module.__dict__, locals()) + else: + return typ + + +@runtime_checkable +class DataclassInstance(Protocol): + __dataclass_fields__: typing.ClassVar[Dict[str, dataclasses.Field]] + + +def is_dataclass_type(typ: Any) -> TypeGuard[Type[DataclassInstance]]: + "True if the argument corresponds to a data class type (but not an instance)." + + typ = unwrap_annotated_type(typ) + return isinstance(typ, type) and dataclasses.is_dataclass(typ) + + +def is_dataclass_instance(obj: Any) -> TypeGuard[DataclassInstance]: + "True if the argument corresponds to a data class instance (but not a type)." + + return not isinstance(obj, type) and dataclasses.is_dataclass(obj) + + +@dataclasses.dataclass +class DataclassField: + name: str + type: Any + default: Any + + def __init__(self, name: str, type: Any, default: Any = dataclasses.MISSING) -> None: + self.name = name + self.type = type + self.default = default + + +def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]: + "Generates the fields of a data-class resolving forward references." + + for field in dataclasses.fields(cls): + yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default) + + +def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField: + "Looks up a field in a data-class by its field name." 
+ + for field in dataclasses.fields(cls): + if field.name == name: + return DataclassField(field.name, evaluate_member_type(field.type, cls)) + + raise LookupError(f"field `{name}` missing from class `{cls.__name__}`") + + +def is_named_tuple_instance(obj: Any) -> TypeGuard[NamedTuple]: + "True if the argument corresponds to a named tuple instance." + + return is_named_tuple_type(type(obj)) + + +def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]: + """ + True if the argument corresponds to a named tuple type. + + Calling the function `collections.namedtuple` gives a new type that is a subclass of `tuple` (and no other classes) + with a member named `_fields` that is a tuple whose items are all strings. + """ + + if not isinstance(typ, type): + return False + + typ = unwrap_annotated_type(typ) + + b = getattr(typ, "__bases__", None) + if b is None: + return False + + if len(b) != 1 or b[0] != tuple: + return False + + f = getattr(typ, "_fields", None) + if not isinstance(f, tuple): + return False + + return all(isinstance(n, str) for n in f) + + +if sys.version_info >= (3, 11): + + def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]: + "True if the specified type is an enumeration type." + + typ = unwrap_annotated_type(typ) + return isinstance(typ, enum.EnumType) + +else: + + def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]: + "True if the specified type is an enumeration type." + + typ = unwrap_annotated_type(typ) + + # use an explicit isinstance(..., type) check to filter out special forms like generics + return isinstance(typ, type) and issubclass(typ, enum.Enum) + + +def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]: + """ + Returns all unique value types of the `enum.Enum` type in definition order. + """ + + # filter unique enumeration value types by keeping definition order + return list(dict.fromkeys(type(e.value) for e in enum_type)) + + +def extend_enum( + source: Type[enum.Enum], +) -> Callable[[Type[enum.Enum]], Type[enum.Enum]]: + """ + Creates a new enumeration type extending the set of values in an existing type. + + :param source: The existing enumeration type to be extended with new values. + :returns: A new enumeration type with the extended set of values. + """ + + def wrap(extend: Type[enum.Enum]) -> Type[enum.Enum]: + # create new enumeration type combining the values from both types + values: Dict[str, Any] = {} + values.update((e.name, e.value) for e in source) + values.update((e.name, e.value) for e in extend) + enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore + + # assign the newly created type to the same module where the extending class is defined + setattr(enum_class, "__module__", extend.__module__) + setattr(enum_class, "__doc__", extend.__doc__) + setattr(sys.modules[extend.__module__], extend.__name__, enum_class) + + return enum.unique(enum_class) + + return wrap + + +if sys.version_info >= (3, 10): + + def _is_union_like(typ: object) -> bool: + "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`." + + return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType) + +else: + + def _is_union_like(typ: object) -> bool: + "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`." + + return typing.get_origin(typ) is Union + + +def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]: + """ + True if the type annotation corresponds to an optional type (e.g. 
`Optional[T]` or `Union[T1,T2,None]`).
+
+    `Optional[T]` is represented as `Union[T, None]` in classic style, and is equivalent to `T | None` in new style.
+
+    :param strict: True if only `Optional[T]` qualifies as an optional type but `Union[T1, T2, None]` does not.
+    """
+
+    typ = unwrap_annotated_type(typ)
+
+    if _is_union_like(typ):
+        args = typing.get_args(typ)
+        if strict and len(args) != 2:
+            return False
+
+        return type(None) in args
+
+    return False
+
+
+def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
+    """
+    Extracts the inner type of an optional type.
+
+    :param typ: The optional type `Optional[T]`.
+    :returns: The inner type `T`.
+    """
+
+    return rewrap_annotated_type(_unwrap_optional_type, typ)
+
+
+def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
+    "Extracts the type qualified as optional (e.g. returns `T` for `Optional[T]`)."
+
+    # Optional[T] is represented internally as Union[T, None]
+    if not _is_union_like(typ):
+        raise TypeError("optional type must have un-subscripted type of Union")
+
+    # will automatically unwrap Union[T] into T
+    return Union[
+        tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))  # type: ignore
+    ]
+
+
+def is_type_union(typ: object) -> bool:
+    "True if the type annotation corresponds to a union type (e.g. `Union[T1,T2,T3]`)."
+
+    typ = unwrap_annotated_type(typ)
+    if _is_union_like(typ):
+        args = typing.get_args(typ)
+        return len(args) > 2 or type(None) not in args
+
+    return False
+
+
+def unwrap_union_types(typ: object) -> Tuple[object, ...]:
+    """
+    Extracts the inner types of a union type.
+
+    :param typ: The union type `Union[T1, T2, ...]`.
+    :returns: The inner types `T1`, `T2`, etc.
+    """
+
+    typ = unwrap_annotated_type(typ)
+    return _unwrap_union_types(typ)
+
+
+def _unwrap_union_types(typ: object) -> Tuple[object, ...]:
+    "Extracts the types in a union (e.g. returns a tuple of types `T1` and `T2` for `Union[T1, T2]`)."
+
+    if not _is_union_like(typ):
+        raise TypeError("union type must have un-subscripted type of Union")
+
+    return typing.get_args(typ)
+
+
+def is_type_literal(typ: object) -> bool:
+    "True if the specified type is a literal of one or more constant values, e.g. `Literal['string']` or `Literal[42]`."
+
+    typ = unwrap_annotated_type(typ)
+    return typing.get_origin(typ) is Literal
+
+
+def unwrap_literal_value(typ: object) -> Any:
+    """
+    Extracts the single constant value captured by a literal type.
+
+    :param typ: The literal type `Literal[value]`.
+    :returns: The value captured by the literal type.
+    """
+
+    args = unwrap_literal_values(typ)
+    if len(args) != 1:
+        raise TypeError("too many values in literal type")
+
+    return args[0]
+
+
+def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
+    """
+    Extracts the constant values captured by a literal type.
+
+    :param typ: The literal type `Literal[value, ...]`.
+    :returns: A tuple of values captured by the literal type.
+    """
+
+    typ = unwrap_annotated_type(typ)
+    return typing.get_args(typ)
+
+
+def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
+    """
+    Extracts the types of the constant values captured by a literal type.
+
+    :param typ: The literal type `Literal[value, ...]`.
+    :returns: A tuple of item types `T` such that `type(value) == T`.
+    """
+
+    return tuple(type(t) for t in unwrap_literal_values(typ))
+
+
+def is_generic_list(typ: object) -> TypeGuard[Type[list]]:
+    "True if the specified type is a generic list, i.e. `List[T]`."
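+    # For example, `is_generic_list(List[int])` evaluates to True, whereas the bare
+    # built-in `list` does not qualify because it carries no item type.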
+ + typ = unwrap_annotated_type(typ) + return typing.get_origin(typ) is list + + +def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]: + """ + Extracts the item type of a list type. + + :param typ: The list type `List[T]`. + :returns: The item type `T`. + """ + + return rewrap_annotated_type(_unwrap_generic_list, typ) + + +def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]: + "Extracts the item type of a list type (e.g. returns `T` for `List[T]`)." + + (list_type,) = typing.get_args(typ) # unpack single tuple element + return list_type + + +def is_generic_set(typ: object) -> TypeGuard[Type[set]]: + "True if the specified type is a generic set, i.e. `Set[T]`." + + typ = unwrap_annotated_type(typ) + return typing.get_origin(typ) is set + + +def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]: + """ + Extracts the item type of a set type. + + :param typ: The set type `Set[T]`. + :returns: The item type `T`. + """ + + return rewrap_annotated_type(_unwrap_generic_set, typ) + + +def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]: + "Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)." + + (set_type,) = typing.get_args(typ) # unpack single tuple element + return set_type + + +def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]: + "True if the specified type is a generic dictionary, i.e. `Dict[KeyType, ValueType]`." + + typ = unwrap_annotated_type(typ) + return typing.get_origin(typ) is dict + + +def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]: + """ + Extracts the key and value types of a dictionary type as a tuple. + + :param typ: The dictionary type `Dict[K, V]`. + :returns: The key and value types `K` and `V`. + """ + + return _unwrap_generic_dict(unwrap_annotated_type(typ)) + + +def _unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]: + "Extracts the key and value types of a dict type (e.g. returns (`K`, `V`) for `Dict[K, V]`)." + + key_type, value_type = typing.get_args(typ) + return key_type, value_type + + +def is_type_annotated(typ: TypeLike) -> bool: + "True if the type annotation corresponds to an annotated type (i.e. `Annotated[T, ...]`)." + + return getattr(typ, "__metadata__", None) is not None + + +def get_annotation(data_type: TypeLike, annotation_type: Type[T]) -> Optional[T]: + """ + Returns the first annotation on a data type that matches the expected annotation type. + + :param data_type: The annotated type from which to extract the annotation. + :param annotation_type: The annotation class to look for. + :returns: The annotation class instance found (if any). + """ + + metadata = getattr(data_type, "__metadata__", None) + if metadata is not None: + for annotation in metadata: + if isinstance(annotation, annotation_type): + return annotation + + return None + + +def unwrap_annotated_type(typ: T) -> T: + "Extracts the wrapped type from an annotated type (e.g. returns `T` for `Annotated[T, ...]`)." + + if is_type_annotated(typ): + # type is Annotated[T, ...] + return typing.get_args(typ)[0] + else: + # type is a regular type + return typ + + +def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]: + """ + Un-boxes, transforms and re-boxes an optionally annotated type. + + :param transform: A function that maps an un-annotated type to another type. + :param typ: A type to un-box (if necessary), transform, and re-box (if necessary). + """ + + metadata = getattr(typ, "__metadata__", None) + if metadata is not None: + # type is Annotated[T, ...] 
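+        # e.g. for `Annotated[List[int], MaxLength(10)]`, the transform receives `List[int]`
+        # and the `MaxLength(10)` metadata is re-attached to the transformed type below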
+ inner_type = typing.get_args(typ)[0] + else: + # type is a regular type + inner_type = typ + + transformed_type = transform(inner_type) + + if metadata is not None: + return Annotated[(transformed_type, *metadata)] # type: ignore + else: + return transformed_type + + +def get_module_classes(module: types.ModuleType) -> List[type]: + "Returns all classes declared directly in a module." + + def is_class_member(member: object) -> TypeGuard[type]: + return inspect.isclass(member) and member.__module__ == module.__name__ + + return [class_type for _, class_type in inspect.getmembers(module, is_class_member)] + + +if sys.version_info >= (3, 9): + + def get_resolved_hints(typ: type) -> Dict[str, type]: + return typing.get_type_hints(typ, include_extras=True) + +else: + + def get_resolved_hints(typ: type) -> Dict[str, type]: + return typing.get_type_hints(typ) + + +def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]: + "Returns all properties of a class." + + if is_dataclass_type(typ): + return ((field.name, field.type) for field in dataclasses.fields(typ)) + else: + resolved_hints = get_resolved_hints(typ) + return resolved_hints.items() + + +def get_class_property(typ: type, name: str) -> Optional[type]: + "Looks up the annotated type of a property in a class by its property name." + + for property_name, property_type in get_class_properties(typ): + if name == property_name: + return property_type + return None + + +@dataclasses.dataclass +class _ROOT: + pass + + +def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]: + """ + Extracts types directly or indirectly referenced by this type. + + For example, extract `T` from `List[T]`, `Optional[T]` or `Annotated[T, ...]`, `K` and `V` from `Dict[K,V]`, + `A` and `B` from `Union[A,B]`. + + :param typ: A type or special form. + :param module: The context in which types are evaluated. + :returns: Types referenced by the given type or special form. + """ + + collector = TypeCollector() + collector.run(typ, _ROOT, module) + return collector.references + + +class TypeCollector: + """ + Collects types directly or indirectly referenced by a type. + + :param graph: The type dependency graph, linking types to types they depend on. + """ + + graph: Dict[type, Set[type]] + + @property + def references(self) -> Set[type]: + "Types collected by the type collector." + + dependencies = set() + for edges in self.graph.values(): + dependencies.update(edges) + return dependencies + + def __init__(self) -> None: + self.graph = {_ROOT: set()} + + def traverse(self, typ: type) -> None: + "Finds all dependent types of a type." + + self.run(typ, _ROOT, sys.modules[typ.__module__]) + + def traverse_all(self, types: Iterable[type]) -> None: + "Finds all dependent types of a list of types." + + for typ in types: + self.traverse(typ) + + def run( + self, + typ: TypeLike, + cls: Type[DataclassInstance], + module: Optional[types.ModuleType], + ) -> None: + """ + Extracts types indirectly referenced by this type. + + For example, extract `T` from `List[T]`, `Optional[T]` or `Annotated[T, ...]`, `K` and `V` from `Dict[K,V]`, + `A` and `B` from `Union[A,B]`. + + :param typ: A type or special form. + :param cls: A dataclass type being expanded for dependent types. + :param module: The context in which types are evaluated. + :returns: Types referenced by the given type or special form. 
+ """ + + if typ is type(None) or typ is Any: + return + + if isinstance(typ, type): + self.graph[cls].add(typ) + + if typ in self.graph: + return + + self.graph[typ] = set() + + metadata = getattr(typ, "__metadata__", None) + if metadata is not None: + # type is Annotated[T, ...] + arg = typing.get_args(typ)[0] + return self.run(arg, cls, module) + + # type is a forward reference + if isinstance(typ, str) or isinstance(typ, typing.ForwardRef): + if module is None: + raise ValueError("missing context for evaluating types") + + evaluated_type = evaluate_type(typ, module) + return self.run(evaluated_type, cls, module) + + # type is a special form + origin = typing.get_origin(typ) + if origin in [list, dict, frozenset, set, tuple, Union]: + for arg in typing.get_args(typ): + self.run(arg, cls, module) + return + elif origin is Literal: + return + + # type is optional or a union type + if is_type_optional(typ): + return self.run(unwrap_optional_type(typ), cls, module) + if is_type_union(typ): + for union_type in unwrap_union_types(typ): + self.run(union_type, cls, module) + return + + # type is a regular type + elif is_dataclass_type(typ) or is_type_enum(typ) or isinstance(typ, type): + context = sys.modules[typ.__module__] + if is_dataclass_type(typ): + for field in dataclass_fields(typ): + self.run(field.type, typ, context) + else: + for field_name, field_type in get_resolved_hints(typ).items(): + self.run(field_type, typ, context) + return + + raise TypeError(f"expected: type-like; got: {typ}") + + +if sys.version_info >= (3, 10): + + def get_signature(fn: Callable[..., Any]) -> inspect.Signature: + "Extracts the signature of a function." + + return inspect.signature(fn, eval_str=True) + +else: + + def get_signature(fn: Callable[..., Any]) -> inspect.Signature: + "Extracts the signature of a function." + + return inspect.signature(fn) + + +def is_reserved_property(name: str) -> bool: + "True if the name stands for an internal property." + + # filter built-in and special properties + if re.match(r"^__.+__$", name): + return True + + # filter built-in special names + if name in ["_abc_impl"]: + return True + + return False + + +def create_module(name: str) -> types.ModuleType: + """ + Creates a new module dynamically at run-time. + + :param name: Fully qualified name of the new module (with dot notation). + """ + + if name in sys.modules: + raise KeyError(f"{name!r} already in sys.modules") + + spec = importlib.machinery.ModuleSpec(name, None) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + if spec.loader is not None: + spec.loader.exec_module(module) + return module + + +if sys.version_info >= (3, 10): + + def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type: + """ + Creates a new data-class type dynamically. + + :param class_name: The name of new data-class type. + :param fields: A list of fields (and their type) that the new data-class type is expected to have. + :returns: The newly created data-class type. + """ + + # has the `slots` parameter + return dataclasses.make_dataclass(class_name, fields, slots=True) + +else: + + def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type: + """ + Creates a new data-class type dynamically. + + :param class_name: The name of new data-class type. + :param fields: A list of fields (and their type) that the new data-class type is expected to have. + :returns: The newly created data-class type. 
+ """ + + cls = dataclasses.make_dataclass(class_name, fields) + + cls_dict = dict(cls.__dict__) + field_names = tuple(field.name for field in dataclasses.fields(cls)) + + cls_dict["__slots__"] = field_names + + for field_name in field_names: + cls_dict.pop(field_name, None) + cls_dict.pop("__dict__", None) + + qualname = getattr(cls, "__qualname__", None) + cls = type(cls)(cls.__name__, (), cls_dict) + if qualname is not None: + cls.__qualname__ = qualname + + return cls + + +def create_object(typ: Type[T]) -> T: + "Creates an instance of a type." + + if issubclass(typ, Exception): + # exception types need special treatment + e = typ.__new__(typ) + return typing.cast(T, e) + else: + return object.__new__(typ) + + +if sys.version_info >= (3, 9): + TypeOrGeneric = Union[type, types.GenericAlias] + +else: + TypeOrGeneric = object + + +def is_generic_instance(obj: Any, typ: TypeLike) -> bool: + """ + Returns whether an object is an instance of a generic class, a standard class or of a subclass thereof. + + This function checks the following items recursively: + * items of a list + * keys and values of a dictionary + * members of a set + * items of a tuple + * members of a union type + + :param obj: The (possibly generic container) object to check recursively. + :param typ: The expected type of the object. + """ + + if isinstance(typ, typing.ForwardRef): + fwd: typing.ForwardRef = typ + identifier = fwd.__forward_arg__ + typ = eval(identifier) + if isinstance(typ, type): + return isinstance(obj, typ) + else: + return False + + # generic types (e.g. list, dict, set, etc.) + origin_type = typing.get_origin(typ) + if origin_type is list: + if not isinstance(obj, list): + return False + (list_item_type,) = typing.get_args(typ) # unpack single tuple element + list_obj: list = obj + return all(is_generic_instance(item, list_item_type) for item in list_obj) + elif origin_type is dict: + if not isinstance(obj, dict): + return False + key_type, value_type = typing.get_args(typ) + dict_obj: dict = obj + return all( + is_generic_instance(key, key_type) and is_generic_instance(value, value_type) + for key, value in dict_obj.items() + ) + elif origin_type is set: + if not isinstance(obj, set): + return False + (set_member_type,) = typing.get_args(typ) # unpack single tuple element + set_obj: set = obj + return all(is_generic_instance(item, set_member_type) for item in set_obj) + elif origin_type is tuple: + if not isinstance(obj, tuple): + return False + return all( + is_generic_instance(item, tuple_item_type) + for tuple_item_type, item in zip( + (tuple_item_type for tuple_item_type in typing.get_args(typ)), + (item for item in obj), + ) + ) + elif origin_type is Union: + return any(is_generic_instance(obj, member_type) for member_type in typing.get_args(typ)) + elif isinstance(typ, type): + return isinstance(obj, typ) + else: + raise TypeError(f"expected `type` but got: {typ}") + + +class RecursiveChecker: + _pred: Optional[Callable[[type, Any], bool]] + + def __init__(self, pred: Callable[[type, Any], bool]) -> None: + """ + Creates a checker to verify if a predicate applies to all nested member properties of an object recursively. + + :param pred: The predicate to test on member properties. Takes a property type and a property value. + """ + + self._pred = pred + + def pred(self, typ: type, obj: Any) -> bool: + "Acts as a workaround for the type checker mypy." 
+ + assert self._pred is not None + return self._pred(typ, obj) + + def check(self, typ: TypeLike, obj: Any) -> bool: + """ + Checks if a predicate applies to all nested member properties of an object recursively. + + :param typ: The type to recurse into. + :param obj: The object to inspect recursively. Must be an instance of the given type. + :returns: True if all member properties pass the filter predicate. + """ + + # check for well-known types + if ( + typ is type(None) + or typ is bool + or typ is int + or typ is float + or typ is str + or typ is bytes + or typ is datetime.datetime + or typ is datetime.date + or typ is datetime.time + or typ is uuid.UUID + ): + return self.pred(typing.cast(type, typ), obj) + + # generic types (e.g. list, dict, set, etc.) + origin_type = typing.get_origin(typ) + if origin_type is list: + if not isinstance(obj, list): + raise TypeError(f"expected `list` but got: {obj}") + (list_item_type,) = typing.get_args(typ) # unpack single tuple element + list_obj: list = obj + return all(self.check(list_item_type, item) for item in list_obj) + elif origin_type is dict: + if not isinstance(obj, dict): + raise TypeError(f"expected `dict` but got: {obj}") + key_type, value_type = typing.get_args(typ) + dict_obj: dict = obj + return all(self.check(value_type, item) for item in dict_obj.values()) + elif origin_type is set: + if not isinstance(obj, set): + raise TypeError(f"expected `set` but got: {obj}") + (set_member_type,) = typing.get_args(typ) # unpack single tuple element + set_obj: set = obj + return all(self.check(set_member_type, item) for item in set_obj) + elif origin_type is tuple: + if not isinstance(obj, tuple): + raise TypeError(f"expected `tuple` but got: {obj}") + return all( + self.check(tuple_item_type, item) + for tuple_item_type, item in zip( + (tuple_item_type for tuple_item_type in typing.get_args(typ)), + (item for item in obj), + ) + ) + elif origin_type is Union: + return self.pred(typ, obj) # type: ignore[arg-type] + + if not inspect.isclass(typ): + raise TypeError(f"expected `type` but got: {typ}") + + # enumeration type + if issubclass(typ, enum.Enum): + if not isinstance(obj, enum.Enum): + raise TypeError(f"expected `{typ}` but got: {obj}") + return self.pred(typ, obj) + + # class types with properties + if is_named_tuple_type(typ): + if not isinstance(obj, tuple): + raise TypeError(f"expected `NamedTuple` but got: {obj}") + return all( + self.check(field_type, getattr(obj, field_name)) + for field_name, field_type in typing.get_type_hints(typ).items() + ) + elif is_dataclass_type(typ): + if not isinstance(obj, typ): + raise TypeError(f"expected `{typ}` but got: {obj}") + resolved_hints = get_resolved_hints(typ) + return all( + self.check(resolved_hints[field.name], getattr(obj, field.name)) for field in dataclasses.fields(typ) + ) + else: + if not isinstance(obj, typ): + raise TypeError(f"expected `{typ}` but got: {obj}") + return all( + self.check(property_type, getattr(obj, property_name)) + for property_name, property_type in get_class_properties(typ) + ) + + +def check_recursive( + obj: object, + /, + *, + pred: Optional[Callable[[type, Any], bool]] = None, + type_pred: Optional[Callable[[type], bool]] = None, + value_pred: Optional[Callable[[Any], bool]] = None, +) -> bool: + """ + Checks if a predicate applies to all nested member properties of an object recursively. + + :param obj: The object to inspect recursively. + :param pred: The predicate to test on member properties. Takes a property type and a property value. 
+ :param type_pred: Constrains the check to properties of an expected type. Properties of other types pass automatically. + :param value_pred: Verifies a condition on member property values (of an expected type). + :returns: True if all member properties pass the filter predicate(s). + """ + + if type_pred is not None and value_pred is not None: + if pred is not None: + raise TypeError("filter predicate not permitted when type and value predicates are present") + + type_p: Callable[[Type[T]], bool] = type_pred + value_p: Callable[[T], bool] = value_pred + pred = lambda typ, obj: not type_p(typ) or value_p(obj) # noqa: E731 + + elif value_pred is not None: + if pred is not None: + raise TypeError("filter predicate not permitted when value predicate is present") + + value_only_p: Callable[[T], bool] = value_pred + pred = lambda typ, obj: value_only_p(obj) # noqa: E731 + + elif type_pred is not None: + raise TypeError("value predicate required when type predicate is present") + + elif pred is None: + pred = lambda typ, obj: True # noqa: E731 + + return RecursiveChecker(pred).check(type(obj), obj) diff --git a/llama_stack/strong_typing/mapping.py b/llama_stack/strong_typing/mapping.py new file mode 100644 index 000000000..408375a9f --- /dev/null +++ b/llama_stack/strong_typing/mapping.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import keyword +from typing import Optional + +from .auxiliary import Alias +from .inspection import get_annotation + + +def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str: + """ + Map a Python field identifier to a JSON property name. + + Authors may use an underscore appended at the end of a Python identifier as per PEP 8 if it clashes with a Python + keyword: e.g. `in` would become `in_` and `from` would become `from_`. Remove these suffixes when exporting to JSON. + + Authors may supply an explicit alias with the type annotation `Alias`, e.g. `Annotated[MyType, Alias("alias")]`. + """ + + if python_type is not None: + alias = get_annotation(python_type, Alias) + if alias: + return alias.name + + if python_id.endswith("_"): + id = python_id[:-1] + if keyword.iskeyword(id): + return id + + return python_id diff --git a/llama_stack/strong_typing/name.py b/llama_stack/strong_typing/name.py new file mode 100644 index 000000000..a1a2ae5f1 --- /dev/null +++ b/llama_stack/strong_typing/name.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import typing +from typing import Any, Literal, Optional, Tuple, Union + +from .auxiliary import _auxiliary_types +from .inspection import ( + TypeLike, + is_generic_dict, + is_generic_list, + is_type_optional, + is_type_union, + unwrap_generic_dict, + unwrap_generic_list, + unwrap_optional_type, + unwrap_union_types, +) + + +class TypeFormatter: + """ + Type formatter. + + :param use_union_operator: Whether to emit union types as `X | Y` as per PEP 604. 
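+
+    For example, `Union[int, None]` is emitted as `Optional[int]` by default, and as
+    `int | None` when `use_union_operator` is enabled.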
+ """ + + use_union_operator: bool + + def __init__(self, use_union_operator: bool = False) -> None: + self.use_union_operator = use_union_operator + + def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str: + if self.use_union_operator: + return " | ".join(self.python_type_to_str(t) for t in data_type_args) + else: + if len(data_type_args) == 2 and type(None) in data_type_args: + # Optional[T] is represented as Union[T, None] + origin_name = "Optional" + data_type_args = tuple(t for t in data_type_args if t is not type(None)) + else: + origin_name = "Union" + + args = ", ".join(self.python_type_to_str(t) for t in data_type_args) + return f"{origin_name}[{args}]" + + def plain_type_to_str(self, data_type: TypeLike) -> str: + "Returns the string representation of a Python type without metadata." + + # return forward references as the annotation string + if isinstance(data_type, typing.ForwardRef): + fwd: typing.ForwardRef = data_type + return fwd.__forward_arg__ + elif isinstance(data_type, str): + return data_type + + origin = typing.get_origin(data_type) + if origin is not None: + data_type_args = typing.get_args(data_type) + + if origin is dict: # Dict[T] + origin_name = "Dict" + elif origin is list: # List[T] + origin_name = "List" + elif origin is set: # Set[T] + origin_name = "Set" + elif origin is Union: + return self.union_to_str(data_type_args) + elif origin is Literal: + args = ", ".join(repr(arg) for arg in data_type_args) + return f"Literal[{args}]" + else: + origin_name = origin.__name__ + + args = ", ".join(self.python_type_to_str(t) for t in data_type_args) + return f"{origin_name}[{args}]" + + return data_type.__name__ + + def python_type_to_str(self, data_type: TypeLike) -> str: + "Returns the string representation of a Python type." + + if data_type is type(None): + return "None" + + # use compact name for alias types + name = _auxiliary_types.get(data_type) + if name is not None: + return name + + metadata = getattr(data_type, "__metadata__", None) + if metadata is not None: + # type is Annotated[T, ...] + metatuple: Tuple[Any, ...] = metadata + arg = typing.get_args(data_type)[0] + + # check for auxiliary types with user-defined annotations + metaset = set(metatuple) + for auxiliary_type, auxiliary_name in _auxiliary_types.items(): + auxiliary_arg = typing.get_args(auxiliary_type)[0] + if arg is not auxiliary_arg: + continue + + auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None) + if auxiliary_metatuple is None: + continue + + if metaset.issuperset(auxiliary_metatuple): + # type is an auxiliary type with extra annotations + auxiliary_args = ", ".join(repr(m) for m in metatuple if m not in auxiliary_metatuple) + return f"Annotated[{auxiliary_name}, {auxiliary_args}]" + + # type is an annotated type + args = ", ".join(repr(m) for m in metatuple) + return f"Annotated[{self.plain_type_to_str(arg)}, {args}]" + else: + # type is a regular type + return self.plain_type_to_str(data_type) + + +def python_type_to_str(data_type: TypeLike, use_union_operator: bool = False) -> str: + """ + Returns the string representation of a Python type. + + :param use_union_operator: Whether to emit union types as `X | Y` as per PEP 604. + """ + + fmt = TypeFormatter(use_union_operator) + return fmt.python_type_to_str(data_type) + + +def python_type_to_name(data_type: TypeLike, force: bool = False) -> str: + """ + Returns the short name of a Python type. + + :param force: Whether to produce a name for composite types such as generics. 
+ """ + + # use compact name for alias types + name = _auxiliary_types.get(data_type) + if name is not None: + return name + + # unwrap annotated types + metadata = getattr(data_type, "__metadata__", None) + if metadata is not None: + # type is Annotated[T, ...] + arg = typing.get_args(data_type)[0] + return python_type_to_name(arg) + + if force: + # generic types + if is_type_optional(data_type, strict=True): + inner_name = python_type_to_name(unwrap_optional_type(data_type)) + return f"Optional__{inner_name}" + elif is_generic_list(data_type): + item_name = python_type_to_name(unwrap_generic_list(data_type)) + return f"List__{item_name}" + elif is_generic_dict(data_type): + key_type, value_type = unwrap_generic_dict(data_type) + key_name = python_type_to_name(key_type) + value_name = python_type_to_name(value_type) + return f"Dict__{key_name}__{value_name}" + elif is_type_union(data_type): + member_types = unwrap_union_types(data_type) + member_names = "__".join(python_type_to_name(member_type) for member_type in member_types) + return f"Union__{member_names}" + + # named system or user-defined type + if hasattr(data_type, "__name__") and not typing.get_args(data_type): + return data_type.__name__ + + raise TypeError(f"cannot assign a simple name to type: {data_type}") diff --git a/llama_stack/strong_typing/py.typed b/llama_stack/strong_typing/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/llama_stack/strong_typing/schema.py b/llama_stack/strong_typing/schema.py new file mode 100644 index 000000000..ddff7cf82 --- /dev/null +++ b/llama_stack/strong_typing/schema.py @@ -0,0 +1,752 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import dataclasses +import datetime +import decimal +import enum +import functools +import inspect +import json +import typing +import uuid +from copy import deepcopy +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import jsonschema +from typing_extensions import Annotated + +from . 
import docstring +from .auxiliary import ( + Alias, + IntegerRange, + MaxLength, + MinLength, + Precision, + get_auxiliary_format, +) +from .core import JsonArray, JsonObject, JsonType, Schema, StrictJsonType +from .inspection import ( + TypeLike, + enum_value_types, + get_annotation, + get_class_properties, + is_type_enum, + is_type_like, + is_type_optional, + unwrap_optional_type, +) +from .name import python_type_to_name +from .serialization import object_to_json + +# determines the maximum number of distinct enum members up to which a Dict[EnumType, Any] is converted into a JSON +# schema with explicitly listed properties (rather than employing a pattern constraint on property names) +OBJECT_ENUM_EXPANSION_LIMIT = 4 + + +T = TypeVar("T") + + +def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]: + docstr = docstring.parse_type(data_type) + + # check if class has a doc-string other than the auto-generated string assigned by @dataclass + if docstring.has_default_docstring(data_type): + return None, None + + return docstr.short_description, docstr.long_description + + +def get_class_property_docstrings( + data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None +) -> Dict[str, str]: + """ + Extracts the documentation strings associated with the properties of a composite type. + + :param data_type: The object whose properties to iterate over. + :param transform_fun: An optional function that maps a property documentation string to a custom tailored string. + :returns: A dictionary mapping property names to descriptions. + """ + + result = {} + for base in inspect.getmro(data_type): + docstr = docstring.parse_type(base) + for param in docstr.params.values(): + if param.name in result: + continue + + if transform_fun: + description = transform_fun(data_type, param.name, param.description) + else: + description = param.description + + result[param.name] = description + return result + + +def docstring_to_schema(data_type: type) -> Schema: + short_description, long_description = get_class_docstrings(data_type) + schema: Schema = {} + + description = "\n".join(filter(None, [short_description, long_description])) + if description: + schema["description"] = description + return schema + + +def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str: + "Extracts the name of a possibly forward-referenced type." + + if isinstance(data_type, typing.ForwardRef): + forward_type: typing.ForwardRef = data_type + return forward_type.__forward_arg__ + elif isinstance(data_type, str): + return data_type + else: + return data_type.__name__ + + +def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]: + "Creates a type from a forward reference." + + if isinstance(data_type, typing.ForwardRef): + forward_type: typing.ForwardRef = data_type + true_type = eval(forward_type.__forward_code__) + return forward_type.__forward_arg__, true_type + elif isinstance(data_type, str): + true_type = eval(data_type) + return data_type, true_type + else: + return data_type.__name__, data_type + + +@dataclasses.dataclass +class TypeCatalogEntry: + schema: Optional[Schema] + identifier: str + examples: Optional[JsonType] = None + + +class TypeCatalog: + "Maintains an association of well-known Python types to their JSON schema." 
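+    # Illustrative usage: registering a (hypothetical) `MyType` via `add(MyType, schema, "MyType")`
+    # lets both `get(MyType)` and a forward reference to `"MyType"` resolve to the same entry.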
+ + _by_type: Dict[TypeLike, TypeCatalogEntry] + _by_name: Dict[str, TypeCatalogEntry] + + def __init__(self) -> None: + self._by_type = {} + self._by_name = {} + + def __contains__(self, data_type: TypeLike) -> bool: + if isinstance(data_type, typing.ForwardRef): + fwd: typing.ForwardRef = data_type + name = fwd.__forward_arg__ + return name in self._by_name + else: + return data_type in self._by_type + + def add( + self, + data_type: TypeLike, + schema: Optional[Schema], + identifier: str, + examples: Optional[List[JsonType]] = None, + ) -> None: + if isinstance(data_type, typing.ForwardRef): + raise TypeError("forward references cannot be used to register a type") + + if data_type in self._by_type: + raise ValueError(f"type {data_type} is already registered in the catalog") + + entry = TypeCatalogEntry(schema, identifier, examples) + self._by_type[data_type] = entry + self._by_name[identifier] = entry + + def get(self, data_type: TypeLike) -> TypeCatalogEntry: + if isinstance(data_type, typing.ForwardRef): + fwd: typing.ForwardRef = data_type + name = fwd.__forward_arg__ + return self._by_name[name] + else: + return self._by_type[data_type] + + +@dataclasses.dataclass +class SchemaOptions: + definitions_path: str = "#/definitions/" + use_descriptions: bool = True + use_examples: bool = True + property_description_fun: Optional[Callable[[type, str, str], str]] = None + + +class JsonSchemaGenerator: + "Creates a JSON schema with user-defined type definitions." + + type_catalog: ClassVar[TypeCatalog] = TypeCatalog() + types_used: Dict[str, TypeLike] + options: SchemaOptions + + def __init__(self, options: Optional[SchemaOptions] = None): + if options is None: + self.options = SchemaOptions() + else: + self.options = options + self.types_used = {} + + @functools.singledispatchmethod + def _metadata_to_schema(self, arg: object) -> Schema: + # unrecognized annotation + return {} + + @_metadata_to_schema.register + def _(self, arg: IntegerRange) -> Schema: + return {"minimum": arg.minimum, "maximum": arg.maximum} + + @_metadata_to_schema.register + def _(self, arg: Precision) -> Schema: + return { + "multipleOf": 10 ** (-arg.decimal_digits), + "exclusiveMinimum": -(10**arg.integer_digits), + "exclusiveMaximum": (10**arg.integer_digits), + } + + @_metadata_to_schema.register + def _(self, arg: MinLength) -> Schema: + return {"minLength": arg.value} + + @_metadata_to_schema.register + def _(self, arg: MaxLength) -> Schema: + return {"maxLength": arg.value} + + def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema: + if metadata: + for m in metadata: + type_schema.update(self._metadata_to_schema(m)) + return type_schema + + def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]: + """ + Returns the JSON schema associated with a simple, unrestricted type. + + :returns: The schema for a simple type, or `None`. 
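+
+        For example, `int` maps to `{"type": "integer"}` and `bytes` maps to
+        `{"type": "string", "contentEncoding": "base64"}`.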
+ """ + + if typ is type(None): + return {"type": "null"} + elif typ is bool: + return {"type": "boolean"} + elif typ is int: + return {"type": "integer"} + elif typ is float: + return {"type": "number"} + elif typ is str: + if json_schema_extra and "contentEncoding" in json_schema_extra: + return { + "type": "string", + "contentEncoding": json_schema_extra["contentEncoding"], + } + return {"type": "string"} + elif typ is bytes: + return {"type": "string", "contentEncoding": "base64"} + elif typ is datetime.datetime: + # 2018-11-13T20:20:39+00:00 + return { + "type": "string", + "format": "date-time", + } + elif typ is datetime.date: + # 2018-11-13 + return {"type": "string", "format": "date"} + elif typ is datetime.time: + # 20:20:39+00:00 + return {"type": "string", "format": "time"} + elif typ is decimal.Decimal: + return {"type": "number"} + elif typ is uuid.UUID: + # f81d4fae-7dec-11d0-a765-00a0c91e6bf6 + return {"type": "string", "format": "uuid"} + elif typ is Any: + return { + "oneOf": [ + {"type": "null"}, + {"type": "boolean"}, + {"type": "number"}, + {"type": "string"}, + {"type": "array"}, + {"type": "object"}, + ] + } + elif typ is JsonObject: + return {"type": "object"} + elif typ is JsonArray: + return {"type": "array"} + else: + # not a simple type + return None + + def type_to_schema( + self, + data_type: TypeLike, + force_expand: bool = False, + json_schema_extra: Optional[dict] = None, + ) -> Schema: + """ + Returns the JSON schema associated with a type. + + :param data_type: The Python type whose JSON schema to return. + :param force_expand: Forces a JSON schema to be returned even if the type is registered in the catalog of known types. + :returns: The JSON schema associated with the type. + """ + + # short-circuit for common simple types + schema = self._simple_type_to_schema(data_type, json_schema_extra) + if schema is not None: + return schema + + # types registered in the type catalog of well-known types + type_catalog = JsonSchemaGenerator.type_catalog + if not force_expand and data_type in type_catalog: + # user-defined type + identifier = type_catalog.get(data_type).identifier + self.types_used.setdefault(identifier, data_type) + return {"$ref": f"{self.options.definitions_path}{identifier}"} + + # unwrap annotated types + metadata = getattr(data_type, "__metadata__", None) + if metadata is not None: + # type is Annotated[T, ...] 
+ typ = typing.get_args(data_type)[0] + schema = self._simple_type_to_schema(typ) + if schema is not None: + # recognize well-known auxiliary types + fmt = get_auxiliary_format(data_type) + if fmt is not None: + schema.update({"format": fmt}) + return schema + else: + return self._with_metadata(schema, metadata) + + else: + # type is a regular type + typ = data_type + + if isinstance(typ, typing.ForwardRef) or isinstance(typ, str): + if force_expand: + identifier, true_type = type_from_ref(typ) + return self.type_to_schema(true_type, force_expand=True) + else: + try: + identifier, true_type = type_from_ref(typ) + self.types_used[identifier] = true_type + except NameError: + identifier = id_from_ref(typ) + + return {"$ref": f"{self.options.definitions_path}{identifier}"} + + if is_type_enum(typ): + enum_type: Type[enum.Enum] = typ + value_types = enum_value_types(enum_type) + if len(value_types) != 1: + raise ValueError( + f"enumerations must have a consistent member value type but several types found: {value_types}" + ) + enum_value_type = value_types.pop() + + enum_schema: Schema + if enum_value_type is bool or enum_value_type is int or enum_value_type is float or enum_value_type is str: + if enum_value_type is bool: + enum_schema_type = "boolean" + elif enum_value_type is int: + enum_schema_type = "integer" + elif enum_value_type is float: + enum_schema_type = "number" + elif enum_value_type is str: + enum_schema_type = "string" + + enum_schema = { + "type": enum_schema_type, + "enum": [object_to_json(e.value) for e in enum_type], + } + if self.options.use_descriptions: + enum_schema.update(docstring_to_schema(typ)) + return enum_schema + else: + enum_schema = self.type_to_schema(enum_value_type) + if self.options.use_descriptions: + enum_schema.update(docstring_to_schema(typ)) + return enum_schema + + origin_type = typing.get_origin(typ) + if origin_type is list: + (list_type,) = typing.get_args(typ) # unpack single tuple element + return {"type": "array", "items": self.type_to_schema(list_type)} + elif origin_type is dict: + key_type, value_type = typing.get_args(typ) + if not (key_type is str or key_type is int or is_type_enum(key_type)): + raise ValueError("`dict` with key type not coercible to `str` is not supported") + + dict_schema: Schema + value_schema = self.type_to_schema(value_type) + if is_type_enum(key_type): + enum_values = [str(e.value) for e in key_type] + if len(enum_values) > OBJECT_ENUM_EXPANSION_LIMIT: + dict_schema = { + "propertyNames": {"pattern": "^(" + "|".join(enum_values) + ")$"}, + "additionalProperties": value_schema, + } + else: + dict_schema = { + "properties": {value: value_schema for value in enum_values}, + "additionalProperties": False, + } + else: + dict_schema = {"additionalProperties": value_schema} + + schema = {"type": "object"} + schema.update(dict_schema) + return schema + elif origin_type is set: + (set_type,) = typing.get_args(typ) # unpack single tuple element + return { + "type": "array", + "items": self.type_to_schema(set_type), + "uniqueItems": True, + } + elif origin_type is tuple: + args = typing.get_args(typ) + return { + "type": "array", + "minItems": len(args), + "maxItems": len(args), + "prefixItems": [self.type_to_schema(member_type) for member_type in args], + } + elif origin_type is Union: + discriminator = None + if typing.get_origin(data_type) is Annotated: + discriminator = typing.get_args(data_type)[1].discriminator + ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]} + if discriminator: 
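+                # editor's note (illustrative): the block below emits an OpenAPI-style
+                # discriminator object, e.g. {"propertyName": "type",
+                # "mapping": {"cat": "#/definitions/Cat", "dog": "#/definitions/Dog"}}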
+ # for each union type, we need to read the value of the discriminator + mapping = {} + for union_type in typing.get_args(typ): + props = self.type_to_schema(union_type, force_expand=True)["properties"] + mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"] + + ret["discriminator"] = { + "propertyName": discriminator, + "mapping": mapping, + } + return ret + elif origin_type is Literal: + (literal_value,) = typing.get_args(typ) # unpack value of literal type + schema = self.type_to_schema(type(literal_value)) + schema["const"] = literal_value + return schema + elif origin_type is type: + (concrete_type,) = typing.get_args(typ) # unpack single tuple element + return {"const": self.type_to_schema(concrete_type, force_expand=True)} + + # dictionary of class attributes + members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a))) + + property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun) + properties: Dict[str, Schema] = {} + required: List[str] = [] + for property_name, property_type in get_class_properties(typ): + # rename property if an alias name is specified + alias = get_annotation(property_type, Alias) + if alias: + output_name = alias.name + else: + output_name = property_name + + defaults = {} + json_schema_extra = None + if "model_fields" in members: + f = members["model_fields"] + defaults = {k: finfo.default for k, finfo in f.items()} + json_schema_extra = f.get(output_name, None).json_schema_extra + + if is_type_optional(property_type): + optional_type: type = unwrap_optional_type(property_type) + property_def = self.type_to_schema(optional_type, json_schema_extra=json_schema_extra) + else: + property_def = self.type_to_schema(property_type, json_schema_extra=json_schema_extra) + required.append(output_name) + + # check if attribute has a default value initializer + if defaults.get(property_name) is not None: + def_value = defaults[property_name] + # check if value can be directly represented in JSON + if isinstance( + def_value, + ( + bool, + int, + float, + str, + enum.Enum, + datetime.datetime, + datetime.date, + datetime.time, + ), + ): + property_def["default"] = object_to_json(def_value) + + # add property docstring if available + property_doc = property_docstrings.get(property_name) + if property_doc: + # print(output_name, property_doc) + property_def.pop("title", None) + property_def["description"] = property_doc + + properties[output_name] = property_def + + schema = {"type": "object"} + if len(properties) > 0: + schema["properties"] = typing.cast(JsonType, properties) + schema["additionalProperties"] = False + if len(required) > 0: + schema["required"] = typing.cast(JsonType, required) + if self.options.use_descriptions: + schema.update(docstring_to_schema(typ)) + return schema + + def _type_to_schema_with_lookup(self, data_type: TypeLike) -> Schema: + """ + Returns the JSON schema associated with a type that may be registered in the catalog of known types. + + :param data_type: The type whose JSON schema we seek. + :returns: The JSON schema associated with the type. 
+ """ + + entry = JsonSchemaGenerator.type_catalog.get(data_type) + if entry.schema is None: + type_schema = self.type_to_schema(data_type, force_expand=True) + else: + type_schema = deepcopy(entry.schema) + + # add descriptive text (if present) + if self.options.use_descriptions: + if isinstance(data_type, type) and not isinstance(data_type, typing.ForwardRef): + type_schema.update(docstring_to_schema(data_type)) + + # add example (if present) + if self.options.use_examples and entry.examples: + type_schema["examples"] = entry.examples + + return type_schema + + def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]: + """ + Returns the JSON schema associated with a type and any nested types. + + :param data_type: The type whose JSON schema to return. + :param force_expand: True if a full JSON schema is to be returned even for well-known types; false if a schema + reference is to be used for well-known types. + :returns: A tuple of the JSON schema, and a mapping between nested type names and their corresponding schema. + """ + + if not is_type_like(data_type): + raise TypeError(f"expected a type-like object but got: {data_type}") + + self.types_used = {} + try: + type_schema = self.type_to_schema(data_type, force_expand=force_expand) + + types_defined: Dict[str, Schema] = {} + while len(self.types_used) > len(types_defined): + # make a snapshot copy; original collection is going to be modified + types_undefined = { + sub_name: sub_type + for sub_name, sub_type in self.types_used.items() + if sub_name not in types_defined + } + + # expand undefined types, which may lead to additional types to be defined + for sub_name, sub_type in types_undefined.items(): + types_defined[sub_name] = self._type_to_schema_with_lookup(sub_type) + + type_definitions = dict(sorted(types_defined.items())) + finally: + self.types_used = {} + + return type_schema, type_definitions + + +class Validator(enum.Enum): + "Defines constants for JSON schema standards." + + Draft7 = jsonschema.Draft7Validator + Draft201909 = jsonschema.Draft201909Validator + Draft202012 = jsonschema.Draft202012Validator + Latest = jsonschema.Draft202012Validator + + +def classdef_to_schema( + data_type: TypeLike, + options: Optional[SchemaOptions] = None, + validator: Validator = Validator.Latest, +) -> Schema: + """ + Returns the JSON schema corresponding to the given type. + + :param data_type: The Python type used to generate the JSON schema + :returns: A JSON object that you can serialize to a JSON string with json.dump or json.dumps + :raises TypeError: Indicates that the generated JSON schema does not validate against the desired meta-schema. 
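+
+    Example (editor's sketch, not part of the upstream patch)::
+
+        classdef_to_schema(uuid.UUID)
+        # -> {"$schema": "https://json-schema.org/draft/2020-12/schema",
+        #     "type": "string", "format": "uuid"}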
+    """
+
+    # short-circuit with an error message when passing invalid data
+    if not is_type_like(data_type):
+        raise TypeError(f"expected a type-like object but got: {data_type}")
+
+    generator = JsonSchemaGenerator(options)
+    type_schema, type_definitions = generator.classdef_to_schema(data_type)
+
+    class_schema: Schema = {}
+    if type_definitions:
+        class_schema["definitions"] = typing.cast(JsonType, type_definitions)
+    class_schema.update(type_schema)
+
+    validator_id = validator.value.META_SCHEMA["$id"]
+    try:
+        validator.value.check_schema(class_schema)
+    except jsonschema.exceptions.SchemaError:
+        raise TypeError(f"schema does not validate against meta-schema <{validator_id}>")
+
+    schema = {"$schema": validator_id}
+    schema.update(class_schema)
+    return schema
+
+
+def validate_object(data_type: TypeLike, json_dict: JsonType) -> None:
+    """
+    Validates if the JSON dictionary object conforms to the expected type.
+
+    :param data_type: The type to match against.
+    :param json_dict: A JSON object obtained with `json.load` or `json.loads`.
+    :raises jsonschema.exceptions.ValidationError: Indicates that the JSON object cannot represent the type.
+    """
+
+    schema_dict = classdef_to_schema(data_type)
+    jsonschema.validate(json_dict, schema_dict, format_checker=jsonschema.FormatChecker())
+
+
+def print_schema(data_type: type) -> None:
+    """Pretty-prints the JSON schema corresponding to the type."""
+
+    s = classdef_to_schema(data_type)
+    print(json.dumps(s, indent=4))
+
+
+def get_schema_identifier(data_type: type) -> Optional[str]:
+    if data_type in JsonSchemaGenerator.type_catalog:
+        return JsonSchemaGenerator.type_catalog.get(data_type).identifier
+    else:
+        return None
+
+
+def register_schema(
+    data_type: T,
+    schema: Optional[Schema] = None,
+    name: Optional[str] = None,
+    examples: Optional[List[JsonType]] = None,
+) -> T:
+    """
+    Associates a type with a JSON schema definition.
+
+    :param data_type: The type to associate with a JSON schema.
+    :param schema: The schema to associate the type with. Derived automatically if omitted.
+    :param name: The name used for looking up the type. Determined automatically if omitted.
+    :param examples: Optional sample JSON values attached to the schema as its "examples" annotation.
+    :returns: The input type.
+    """
+
+    JsonSchemaGenerator.type_catalog.add(
+        data_type,
+        schema,
+        name if name is not None else python_type_to_name(data_type),
+        examples,
+    )
+    return data_type
+
+
+@overload
+def json_schema_type(cls: Type[T], /) -> Type[T]: ...
+
+
+@overload
+def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ...
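+
+
+# Editor's sketch (illustrative, not part of the upstream patch): the two
+# overloads above correspond to the decorator's two calling conventions, e.g.
+#
+#   @json_schema_type                    # bare form: schema derived from the class
+#   class Example: ...
+#
+#   @json_schema_type(schema={"type": "string", "format": "uri"})
+#   class ExampleURL(str): ...           # parenthesized form: hand-written schema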
+
+
+def json_schema_type(
+    cls: Optional[Type[T]] = None,
+    *,
+    schema: Optional[Schema] = None,
+    examples: Optional[List[JsonType]] = None,
+) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
+    """Decorator to add a user-defined schema definition to a class."""
+
+    def wrap(cls: Type[T]) -> Type[T]:
+        return register_schema(cls, schema, examples=examples)
+
+    # see if decorator is used as @json_schema_type or @json_schema_type()
+    if cls is None:
+        # called with parentheses
+        return wrap
+    else:
+        # called as @json_schema_type without parentheses
+        return wrap(cls)
+
+
+register_schema(JsonObject, name="JsonObject")
+register_schema(JsonArray, name="JsonArray")
+
+register_schema(
+    JsonType,
+    name="JsonType",
+    examples=[
+        {
+            "property1": None,
+            "property2": True,
+            "property3": 64,
+            "property4": "string",
+            "property5": ["item"],
+            "property6": {"key": "value"},
+        }
+    ],
+)
+register_schema(
+    StrictJsonType,
+    name="StrictJsonType",
+    examples=[
+        {
+            "property1": True,
+            "property2": 64,
+            "property3": "string",
+            "property4": ["item"],
+            "property5": {"key": "value"},
+        }
+    ],
+)
diff --git a/llama_stack/strong_typing/serialization.py b/llama_stack/strong_typing/serialization.py
new file mode 100644
index 000000000..c00a0aad5
--- /dev/null
+++ b/llama_stack/strong_typing/serialization.py
@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Type-safe data interchange for Python data classes.
+
+:see: https://github.com/hunyadi/strong_typing
+"""
+
+import inspect
+import json
+import sys
+from types import ModuleType
+from typing import Any, Optional, TextIO, TypeVar
+
+from .core import JsonType
+from .deserializer import create_deserializer
+from .inspection import TypeLike
+from .serializer import create_serializer
+
+T = TypeVar("T")
+
+
+def object_to_json(obj: Any) -> JsonType:
+    """
+    Converts a Python object to a representation that can be exported to JSON.
+
+    * Fundamental types (e.g. numeric types) are written as is.
+    * Date and time types are serialized in the ISO 8601 format with time zone.
+    * A byte array is written as a string with Base64 encoding.
+    * UUIDs are written as a UUID string.
+    * Enumerations are written as their value.
+    * Containers (e.g. `list`, `dict`, `set`, `tuple`) are exported recursively.
+    * Objects with properties (including data class types) are converted to dictionaries of key-value pairs.
+    """
+
+    typ: type = type(obj)
+    generator = create_serializer(typ)
+    return generator.generate(obj)
+
+
+def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object:
+    """
+    Creates an object from a representation that has been de-serialized from JSON.
+
+    When de-serializing a JSON object into a Python object, the following transformations are applied:
+
+    * Fundamental types are parsed as `bool`, `int`, `float` or `str`.
+    * Date and time types are parsed from the ISO 8601 format with time zone into the corresponding Python type
+      `datetime`, `date` or `time`.
+    * A byte array is read from a string with Base64 encoding into a `bytes` instance.
+    * UUIDs are extracted from a UUID string into a `uuid.UUID` instance.
+    * Enumerations are instantiated with a lookup on enumeration value.
+    * Containers (e.g. `list`, `dict`, `set`, `tuple`) are parsed recursively.
+ * Complex objects with properties (including data class types) are populated from dictionaries of key-value pairs + using reflection (enumerating type annotations). + + :raises TypeError: A de-serializing engine cannot be constructed for the input type. + :raises JsonKeyError: Deserialization for a class or union type has failed because a matching member was not found. + :raises JsonTypeError: Deserialization for data has failed due to a type mismatch. + """ + + # use caller context for evaluating types if no context is supplied + if context is None: + this_frame = inspect.currentframe() + if this_frame is not None: + caller_frame = this_frame.f_back + del this_frame + + if caller_frame is not None: + try: + context = sys.modules[caller_frame.f_globals["__name__"]] + finally: + del caller_frame + + parser = create_deserializer(typ, context) + return parser.parse(data) + + +def json_dump_string(json_object: JsonType) -> str: + "Dump an object as a JSON string with a compact representation." + + return json.dumps(json_object, ensure_ascii=False, check_circular=False, separators=(",", ":")) + + +def json_dump(json_object: JsonType, file: TextIO) -> None: + json.dump( + json_object, + file, + ensure_ascii=False, + check_circular=False, + separators=(",", ":"), + ) + file.write("\n") diff --git a/llama_stack/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py new file mode 100644 index 000000000..5e93e4c4d --- /dev/null +++ b/llama_stack/strong_typing/serializer.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Type-safe data interchange for Python data classes. + +:see: https://github.com/hunyadi/strong_typing +""" + +import abc +import base64 +import datetime +import enum +import functools +import inspect +import ipaddress +import sys +import typing +import uuid +from types import FunctionType, MethodType, ModuleType +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Literal, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from .core import JsonType +from .exception import JsonTypeError, JsonValueError +from .inspection import ( + TypeLike, + enum_value_types, + evaluate_type, + get_class_properties, + get_resolved_hints, + is_dataclass_type, + is_named_tuple_type, + is_reserved_property, + is_type_annotated, + is_type_enum, + unwrap_annotated_type, +) +from .mapping import python_field_to_json_property + +T = TypeVar("T") + + +class Serializer(abc.ABC, Generic[T]): + @abc.abstractmethod + def generate(self, data: T) -> JsonType: ... 
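+
+
+# Editor's sketch (illustrative, not part of the upstream patch): every concrete
+# class below implements this interface for one Python type; create_serializer()
+# (defined further down in this file) selects the appropriate one, e.g.
+#
+#   ser = create_serializer(datetime.datetime)
+#   ser.generate(datetime.datetime(2018, 11, 13, 20, 20, 39, tzinfo=datetime.timezone.utc))
+#   # -> "2018-11-13T20:20:39Z"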
+ + +class NoneSerializer(Serializer[None]): + def generate(self, data: None) -> None: + # can be directly represented in JSON + return None + + +class BoolSerializer(Serializer[bool]): + def generate(self, data: bool) -> bool: + # can be directly represented in JSON + return data + + +class IntSerializer(Serializer[int]): + def generate(self, data: int) -> int: + # can be directly represented in JSON + return data + + +class FloatSerializer(Serializer[float]): + def generate(self, data: float) -> float: + # can be directly represented in JSON + return data + + +class StringSerializer(Serializer[str]): + def generate(self, data: str) -> str: + # can be directly represented in JSON + return data + + +class BytesSerializer(Serializer[bytes]): + def generate(self, data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +class DateTimeSerializer(Serializer[datetime.datetime]): + def generate(self, obj: datetime.datetime) -> str: + if obj.tzinfo is None: + raise JsonValueError(f"timestamp lacks explicit time zone designator: {obj}") + fmt = obj.isoformat() + if fmt.endswith("+00:00"): + fmt = f"{fmt[:-6]}Z" # Python's isoformat() does not support military time zones like "Zulu" for UTC + return fmt + + +class DateSerializer(Serializer[datetime.date]): + def generate(self, obj: datetime.date) -> str: + return obj.isoformat() + + +class TimeSerializer(Serializer[datetime.time]): + def generate(self, obj: datetime.time) -> str: + return obj.isoformat() + + +class UUIDSerializer(Serializer[uuid.UUID]): + def generate(self, obj: uuid.UUID) -> str: + return str(obj) + + +class IPv4Serializer(Serializer[ipaddress.IPv4Address]): + def generate(self, obj: ipaddress.IPv4Address) -> str: + return str(obj) + + +class IPv6Serializer(Serializer[ipaddress.IPv6Address]): + def generate(self, obj: ipaddress.IPv6Address) -> str: + return str(obj) + + +class EnumSerializer(Serializer[enum.Enum]): + def generate(self, obj: enum.Enum) -> Union[int, str]: + return obj.value + + +class UntypedListSerializer(Serializer[list]): + def generate(self, obj: list) -> List[JsonType]: + return [object_to_json(item) for item in obj] + + +class UntypedDictSerializer(Serializer[dict]): + def generate(self, obj: dict) -> Dict[str, JsonType]: + if obj and isinstance(next(iter(obj.keys())), enum.Enum): + iterator = ((key.value, object_to_json(value)) for key, value in obj.items()) + else: + iterator = ((str(key), object_to_json(value)) for key, value in obj.items()) + return dict(iterator) + + +class UntypedSetSerializer(Serializer[set]): + def generate(self, obj: set) -> List[JsonType]: + return [object_to_json(item) for item in obj] + + +class UntypedTupleSerializer(Serializer[tuple]): + def generate(self, obj: tuple) -> List[JsonType]: + return [object_to_json(item) for item in obj] + + +class TypedCollectionSerializer(Serializer, Generic[T]): + generator: Serializer[T] + + def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None: + self.generator = _get_serializer(item_type, context) + + +class TypedListSerializer(TypedCollectionSerializer[T]): + def generate(self, obj: List[T]) -> List[JsonType]: + return [self.generator.generate(item) for item in obj] + + +class TypedStringDictSerializer(TypedCollectionSerializer[T]): + def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None: + super().__init__(value_type, context) + + def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]: + return {key: self.generator.generate(value) for key, value in obj.items()} + + +class 
TypedEnumDictSerializer(TypedCollectionSerializer[T]): + def __init__( + self, + key_type: Type[enum.Enum], + value_type: Type[T], + context: Optional[ModuleType], + ) -> None: + super().__init__(value_type, context) + + value_types = enum_value_types(key_type) + if len(value_types) != 1: + raise JsonTypeError( + f"invalid key type, enumerations must have a consistent member value type but several types found: {value_types}" + ) + + value_type = value_types.pop() + if value_type is not str: + raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values") + + def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]: + return {key.value: self.generator.generate(value) for key, value in obj.items()} + + +class TypedSetSerializer(TypedCollectionSerializer[T]): + def generate(self, obj: Set[T]) -> JsonType: + return [self.generator.generate(item) for item in obj] + + +class TypedTupleSerializer(Serializer[tuple]): + item_generators: Tuple[Serializer, ...] + + def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None: + self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types) + + def generate(self, obj: tuple) -> List[JsonType]: + return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)] + + +class CustomSerializer(Serializer): + converter: Callable[[object], JsonType] + + def __init__(self, converter: Callable[[object], JsonType]) -> None: + self.converter = converter + + def generate(self, obj: object) -> JsonType: + return self.converter(obj) + + +class FieldSerializer(Generic[T]): + """ + Serializes a Python object field into a JSON property. + + :param field_name: The name of the field in a Python class to read data from. + :param property_name: The name of the JSON property to write to a JSON `object`. + :param generator: A compatible serializer that can handle the field's type. 
+ """ + + field_name: str + property_name: str + generator: Serializer + + def __init__(self, field_name: str, property_name: str, generator: Serializer[T]) -> None: + self.field_name = field_name + self.property_name = property_name + self.generator = generator + + def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None: + value = getattr(obj, self.field_name) + if value is not None: + object_dict[self.property_name] = self.generator.generate(value) + + +class TypedClassSerializer(Serializer[T]): + property_generators: List[FieldSerializer] + + def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None: + self.property_generators = [ + FieldSerializer( + field_name, + python_field_to_json_property(field_name, field_type), + _get_serializer(field_type, context), + ) + for field_name, field_type in get_class_properties(class_type) + ] + + def generate(self, obj: T) -> Dict[str, JsonType]: + object_dict: Dict[str, JsonType] = {} + for property_generator in self.property_generators: + property_generator.generate_field(obj, object_dict) + + return object_dict + + +class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]): + def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None: + super().__init__(class_type, context) + + +class DataclassSerializer(TypedClassSerializer[T]): + def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None: + super().__init__(class_type, context) + + +class UnionSerializer(Serializer): + def generate(self, obj: Any) -> JsonType: + return object_to_json(obj) + + +class LiteralSerializer(Serializer): + generator: Serializer + + def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None: + literal_type_tuple = tuple(type(value) for value in values) + literal_type_set = set(literal_type_tuple) + if len(literal_type_set) != 1: + value_names = ", ".join(repr(value) for value in values) + raise TypeError( + f"type `Literal[{value_names}]` expects consistent literal value types but got: {literal_type_tuple}" + ) + + literal_type = literal_type_set.pop() + self.generator = _get_serializer(literal_type, context) + + def generate(self, obj: Any) -> JsonType: + return self.generator.generate(obj) + + +class UntypedNamedTupleSerializer(Serializer): + fields: Dict[str, str] + + def __init__(self, class_type: Type[NamedTuple]) -> None: + # named tuples are also instances of tuple + self.fields = {} + field_names: Tuple[str, ...] 
= class_type._fields + for field_name in field_names: + self.fields[field_name] = python_field_to_json_property(field_name) + + def generate(self, obj: NamedTuple) -> JsonType: + object_dict = {} + for field_name, property_name in self.fields.items(): + value = getattr(obj, field_name) + object_dict[property_name] = object_to_json(value) + + return object_dict + + +class UntypedClassSerializer(Serializer): + def generate(self, obj: object) -> JsonType: + # iterate over object attributes to get a standard representation + object_dict = {} + for name in dir(obj): + if is_reserved_property(name): + continue + + value = getattr(obj, name) + if value is None: + continue + + # filter instance methods + if inspect.ismethod(value): + continue + + object_dict[python_field_to_json_property(name)] = object_to_json(value) + + return object_dict + + +def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer: + """ + Creates a serializer engine to produce an object that can be directly converted into a JSON string. + + When serializing a Python object into a JSON object, the following transformations are applied: + + * Fundamental types (`bool`, `int`, `float` or `str`) are returned as-is. + * Date and time types (`datetime`, `date` or `time`) produce an ISO 8601 format string with time zone + (ending with `Z` for UTC). + * Byte arrays (`bytes`) are written as a string with Base64 encoding. + * UUIDs (`uuid.UUID`) are written as a UUID string as per RFC 4122. + * Enumerations yield their enumeration value. + * Containers (e.g. `list`, `dict`, `set`, `tuple`) are processed recursively. + * Complex objects with properties (including data class types) generate dictionaries of key-value pairs. + + :raises TypeError: A serializer engine cannot be constructed for the input type. 
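+
+    Example (editor's sketch, not part of the upstream patch; assumes
+    `from dataclasses import dataclass`)::
+
+        @dataclass
+        class Point:
+            x: int
+            y: int
+
+        create_serializer(Point).generate(Point(1, 2))
+        # -> {"x": 1, "y": 2}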
+ """ + + if context is None: + if isinstance(typ, type): + context = sys.modules[typ.__module__] + + return _get_serializer(typ, context) + + +def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer: + if isinstance(typ, (str, typing.ForwardRef)): + if context is None: + raise TypeError(f"missing context for evaluating type: {typ}") + + typ = evaluate_type(typ, context) + + if isinstance(typ, type): + return _fetch_serializer(typ) + else: + # special forms are not always hashable + return _create_serializer(typ, context) + + +@functools.lru_cache(maxsize=None) +def _fetch_serializer(typ: type) -> Serializer: + context = sys.modules[typ.__module__] + return _create_serializer(typ, context) + + +def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer: + # check for well-known types + if typ is type(None): + return NoneSerializer() + elif typ is bool: + return BoolSerializer() + elif typ is int: + return IntSerializer() + elif typ is float: + return FloatSerializer() + elif typ is str: + return StringSerializer() + elif typ is bytes: + return BytesSerializer() + elif typ is datetime.datetime: + return DateTimeSerializer() + elif typ is datetime.date: + return DateSerializer() + elif typ is datetime.time: + return TimeSerializer() + elif typ is uuid.UUID: + return UUIDSerializer() + elif typ is ipaddress.IPv4Address: + return IPv4Serializer() + elif typ is ipaddress.IPv6Address: + return IPv6Serializer() + + # dynamically-typed collection types + if typ is list: + return UntypedListSerializer() + elif typ is dict: + return UntypedDictSerializer() + elif typ is set: + return UntypedSetSerializer() + elif typ is tuple: + return UntypedTupleSerializer() + + # generic types (e.g. list, dict, set, etc.) + origin_type = typing.get_origin(typ) + if origin_type is list: + (list_item_type,) = typing.get_args(typ) # unpack single tuple element + return TypedListSerializer(list_item_type, context) + elif origin_type is dict: + key_type, value_type = typing.get_args(typ) + if key_type is str: + return TypedStringDictSerializer(value_type, context) + elif issubclass(key_type, enum.Enum): + return TypedEnumDictSerializer(key_type, value_type, context) + elif origin_type is set: + (set_member_type,) = typing.get_args(typ) # unpack single tuple element + return TypedSetSerializer(set_member_type, context) + elif origin_type is tuple: + return TypedTupleSerializer(typing.get_args(typ), context) + elif origin_type is Union: + return UnionSerializer() + elif origin_type is Literal: + return LiteralSerializer(typing.get_args(typ), context) + + if is_type_annotated(typ): + return create_serializer(unwrap_annotated_type(typ)) + + # check if object has custom serialization method + convert_func = getattr(typ, "to_json", None) + if callable(convert_func): + return CustomSerializer(convert_func) + + if is_type_enum(typ): + return EnumSerializer() + if is_dataclass_type(typ): + return DataclassSerializer(typ, context) + if is_named_tuple_type(typ): + if getattr(typ, "__annotations__", None): + return TypedNamedTupleSerializer(typ, context) + else: + return UntypedNamedTupleSerializer(typ) + + # fail early if caller passes an object with an exotic type + if not isinstance(typ, type) or typ is FunctionType or typ is MethodType or typ is type or typ is ModuleType: + raise TypeError(f"object of type {typ} cannot be represented in JSON") + + if get_resolved_hints(typ): + return TypedClassSerializer(typ, context) + else: + return UntypedClassSerializer() + + +def 
object_to_json(obj: Any) -> JsonType:
+    """
+    Converts a Python object to a representation that can be exported to JSON.
+
+    * Fundamental types (e.g. numeric types) are written as is.
+    * Date and time types are serialized in the ISO 8601 format with time zone.
+    * A byte array is written as a string with Base64 encoding.
+    * UUIDs are written as a UUID string.
+    * Enumerations are written as their value.
+    * Containers (e.g. `list`, `dict`, `set`, `tuple`) are exported recursively.
+    * Objects with properties (including data class types) are converted to dictionaries of key-value pairs.
+    """
+
+    typ: type = type(obj)
+    generator = create_serializer(typ)
+    return generator.generate(obj)
diff --git a/llama_stack/strong_typing/slots.py b/llama_stack/strong_typing/slots.py
new file mode 100644
index 000000000..c1a3293d8
--- /dev/null
+++ b/llama_stack/strong_typing/slots.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Tuple, Type, TypeVar
+
+T = TypeVar("T")
+
+
+class SlotsMeta(type):
+    def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T:
+        # caller may have already provided slots, in which case just retain them and keep going
+        slots: Tuple[str, ...] = ns.get("__slots__", ())
+
+        # add fields with type annotations to slots
+        annotations: Dict[str, Any] = ns.get("__annotations__", {})
+        members = tuple(member for member in annotations.keys() if member not in slots)
+
+        # assign slots
+        ns["__slots__"] = slots + tuple(members)
+        return super().__new__(cls, name, bases, ns)  # type: ignore
+
+
+class Slots(metaclass=SlotsMeta):
+    pass
diff --git a/llama_stack/strong_typing/topological.py b/llama_stack/strong_typing/topological.py
new file mode 100644
index 000000000..28bf4bd0f
--- /dev/null
+++ b/llama_stack/strong_typing/topological.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Type-safe data interchange for Python data classes.
+
+:see: https://github.com/hunyadi/strong_typing
+"""
+
+from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar
+
+from .inspection import TypeCollector
+
+T = TypeVar("T")
+
+
+def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
+    """
+    Performs a topological sort of a graph.
+
+    Nodes with no outgoing edges are first. Nodes with no incoming edges are last.
+    The topological ordering is not unique.
+
+    :param graph: A dictionary of mappings from nodes to adjacent nodes. Keys and set members must be hashable.
+    :returns: The list of nodes in topological order.
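+
+    Example (editor's sketch, not part of the upstream patch)::
+
+        topological_sort({"a": {"b"}, "b": {"c"}, "c": set()})
+        # -> ["c", "b", "a"]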
+ """ + + # empty list that will contain the sorted nodes (in reverse order) + ordered: List[T] = [] + + seen: Dict[T, bool] = {} + + def _visit(n: T) -> None: + status = seen.get(n) + if status is not None: + if status: # node has a permanent mark + return + else: # node has a temporary mark + raise RuntimeError(f"cycle detected in graph for node {n}") + + seen[n] = False # apply temporary mark + for m in graph[n]: # visit all adjacent nodes + if m != n: # ignore self-referencing nodes + _visit(m) + + seen[n] = True # apply permanent mark + ordered.append(n) + + for n in graph.keys(): + _visit(n) + + return ordered + + +def type_topological_sort( + types: Iterable[type], + dependency_fn: Optional[Callable[[type], Iterable[type]]] = None, +) -> List[type]: + """ + Performs a topological sort of a list of types. + + Types that don't depend on other types (i.e. fundamental types) are first. Types on which no other types depend + are last. The topological ordering is not unique. + + :param types: A list of types (simple or composite). + :param dependency_fn: Returns a list of additional dependencies for a class (e.g. classes referenced by a foreign key). + :returns: The list of types in topological order. + """ + + if not all(isinstance(typ, type) for typ in types): + raise TypeError("expected a list of types") + + collector = TypeCollector() + collector.traverse_all(types) + graph = collector.graph + + if dependency_fn: + new_types: Set[type] = set() + for source_type, references in graph.items(): + dependent_types = dependency_fn(source_type) + references.update(dependent_types) + new_types.update(dependent_types) + for new_type in new_types: + graph[new_type] = set() + + return topological_sort(graph) diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py index af1d48b7f..0b294824d 100644 --- a/llama_stack/templates/bedrock/bedrock.py +++ b/llama_stack/templates/bedrock/bedrock.py @@ -6,10 +6,9 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index 870240feb..4f6d0c8f3 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -6,10 +6,9 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index e2e2ca99c..a6809fef6 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -6,8 +6,6 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from 
llama_stack.distribution.datatypes import ( ModelInput, @@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index d24c9ed48..ee22b5555 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -6,9 +6,8 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 6d7477c8e..c7a9428af 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -6,14 +6,13 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.distribution.datatypes import ( ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index 9ec5b38ba..f7b18e32a 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -6,8 +6,6 @@ from pathlib import Path -from llama_models.sku_list import all_registered_models - from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( ModelInput, @@ -15,6 +13,7 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) diff --git a/pyproject.toml b/pyproject.toml index feaae153b..8b0135c70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "fire", "httpx", "huggingface-hub", + "jsonschema", "llama-models>=0.1.2", "llama-stack-client>=0.1.2", "prompt-toolkit", diff --git a/requirements.txt b/requirements.txt index 497feb764..40431e446 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ # This file was autogenerated by uv via the following command: -# uv export --frozen --no-hashes --no-emit-project +# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt annotated-types==0.7.0 anyio==4.8.0 +attrs==25.1.0 blobfile==3.0.0 certifi==2025.1.31 -chardet==5.2.0 charset-normalizer==3.4.1 click==8.1.8 colorama==0.4.6 ; sys_platform == 'win32' @@ -19,6 +19,8 @@ httpx==0.28.1 huggingface-hub==0.28.1 idna==3.10 jinja2==3.1.5 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 llama-models==0.1.2 llama-stack-client==0.1.2 lxml==5.3.0 @@ -35,14 +37,15 @@ 
pycryptodomex==3.21.0 pydantic==2.10.6 pydantic-core==2.27.2 pygments==2.19.1 -pypdf==5.2.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 pytz==2025.1 pyyaml==6.0.2 +referencing==0.36.2 regex==2024.11.6 requests==2.32.3 rich==13.9.4 +rpds-py==0.22.3 setuptools==75.8.0 six==1.17.0 sniffio==1.3.1 diff --git a/tests/client-sdk/report.py b/tests/client-sdk/report.py index 543562541..d36fa827f 100644 --- a/tests/client-sdk/report.py +++ b/tests/client-sdk/report.py @@ -13,8 +13,12 @@ from typing import Optional from urllib.parse import urlparse import pytest -from llama_models.datatypes import CoreModelId -from llama_models.sku_list import ( +from metadata import API_MAPS +from pytest import CollectReport +from termcolor import cprint + +from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_list import ( all_registered_models, llama3_1_instruct_models, llama3_2_instruct_models, @@ -22,10 +26,6 @@ from llama_models.sku_list import ( llama3_instruct_models, safety_models, ) -from metadata import API_MAPS -from pytest import CollectReport -from termcolor import cprint - from llama_stack.providers.datatypes import Api from llama_stack.providers.tests.env import get_env_or_fail diff --git a/uv.lock b/uv.lock index 97ae52124..ed1e4bc2d 100644 --- a/uv.lock +++ b/uv.lock @@ -265,7 +265,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -577,7 +577,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -724,6 +724,7 @@ dependencies = [ { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, + { name = "jsonschema" }, { name = "llama-models" }, { name = "llama-stack-client" }, { name = "prompt-toolkit" }, @@ -768,6 +769,7 @@ requires-dist = [ { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, + { name = "jsonschema" }, { name = "llama-models", specifier = ">=0.1.2" }, { name = "llama-stack-client", specifier = ">=0.1.2" }, { name = "myst-parser", marker = "extra == 'docs'" }, @@ -1412,8 +1414,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/74/49f5d20c514ccc631b940cc9dfec45dcce418dc84a98463a2e2ebec33904/pycryptodomex-3.21.0-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:52e23a0a6e61691134aa8c8beba89de420602541afaae70f66e16060fdcd677e", size = 2257982 }, { url = "https://files.pythonhosted.org/packages/92/4b/d33ef74e2cc0025a259936661bb53432c5bbbadc561c5f2e023bcd73ce4c/pycryptodomex-3.21.0-cp36-abi3-win32.whl", hash = "sha256:a3d77919e6ff56d89aada1bd009b727b874d464cb0e2e3f00a49f7d2e709d76e", size = 1779052 }, { url = "https://files.pythonhosted.org/packages/5b/be/7c991840af1184009fc86267160948350d1bf875f153c97bb471ad944e40/pycryptodomex-3.21.0-cp36-abi3-win_amd64.whl", hash = "sha256:b0e9765f93fe4890f39875e6c90c96cb341767833cfa767f41b490b506fa9ec0", size = 1816307 }, - { url = 
"https://files.pythonhosted.org/packages/af/ac/24125ad36778914a36f08d61ba5338cb9159382c638d9761ee19c8de822c/pycryptodomex-3.21.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:feaecdce4e5c0045e7a287de0c4351284391fe170729aa9182f6bd967631b3a8", size = 1694999 }, - { url = "https://files.pythonhosted.org/packages/93/73/be7a54a5903508070e5508925ba94493a1f326cfeecfff750e3eb250ea28/pycryptodomex-3.21.0-pp27-pypy_73-win32.whl", hash = "sha256:365aa5a66d52fd1f9e0530ea97f392c48c409c2f01ff8b9a39c73ed6f527d36c", size = 1769437 }, { url = "https://files.pythonhosted.org/packages/e5/9f/39a6187f3986841fa6a9f35c6fdca5030ef73ff708b45a993813a51d7d10/pycryptodomex-3.21.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3efddfc50ac0ca143364042324046800c126a1d63816d532f2e19e6f2d8c0c31", size = 1619607 }, { url = "https://files.pythonhosted.org/packages/f8/70/60bb08e9e9841b18d4669fb69d84b64ce900aacd7eb0ebebd4c7b9bdecd3/pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df2608682db8279a9ebbaf05a72f62a321433522ed0e499bc486a6889b96bf3", size = 1653571 }, { url = "https://files.pythonhosted.org/packages/c9/6f/191b73509291c5ff0dddec9cc54797b1d73303c12b2e4017b24678e57099/pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5823d03e904ea3e53aebd6799d6b8ec63b7675b5d2f4a4bd5e3adcb512d03b37", size = 1691548 }, @@ -2305,7 +2305,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [