forked from phoenix-oss/llama-stack-mirror
This is yet another of those large PRs (hopefully we will have less and less of them as things mature fast). This one introduces substantial improvements and some simplifications to the stack. Most important bits: * Agents reference implementation now has support for session / turn persistence. The default implementation uses sqlite but there's also support for using Redis. * We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are: - routing model A to ollama and model B to a remote provider like Together - routing shield A to local impl while shield B to a remote provider like Bedrock - routing a vector memory bank to Weaviate while routing a keyvalue memory bank to Redis * Support for provider specific parameters to be passed from the clients. A client can pass data using `x_llamastack_provider_data` parameter which can be type-checked and provided to the Adapter implementations.
230 lines
5.8 KiB
Python
230 lines
5.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
"""
|
|
Type-safe data interchange for Python data classes.
|
|
|
|
:see: https://github.com/hunyadi/strong_typing
|
|
"""
|
|
|
|
import dataclasses
|
|
import sys
|
|
from dataclasses import is_dataclass
|
|
from typing import Callable, Dict, Optional, overload, Type, TypeVar, Union
|
|
|
|
if sys.version_info >= (3, 9):
|
|
from typing import Annotated as Annotated
|
|
else:
|
|
from typing_extensions import Annotated as Annotated
|
|
|
|
if sys.version_info >= (3, 10):
|
|
from typing import TypeAlias as TypeAlias
|
|
else:
|
|
from typing_extensions import TypeAlias as TypeAlias
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import dataclass_transform as dataclass_transform
|
|
else:
|
|
from typing_extensions import dataclass_transform as dataclass_transform
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def _compact_dataclass_repr(obj: object) -> str:
|
|
"""
|
|
Compact data-class representation where positional arguments are used instead of keyword arguments.
|
|
|
|
:param obj: A data-class object.
|
|
:returns: A string that matches the pattern `Class(arg1, arg2, ...)`.
|
|
"""
|
|
|
|
if is_dataclass(obj):
|
|
arglist = ", ".join(
|
|
repr(getattr(obj, field.name)) for field in dataclasses.fields(obj)
|
|
)
|
|
return f"{obj.__class__.__name__}({arglist})"
|
|
else:
|
|
return obj.__class__.__name__
|
|
|
|
|
|
class CompactDataClass:
|
|
"A data class whose repr() uses positional rather than keyword arguments."
|
|
|
|
def __repr__(self) -> str:
|
|
return _compact_dataclass_repr(self)
|
|
|
|
|
|
@overload
|
|
def typeannotation(cls: Type[T], /) -> Type[T]: ...
|
|
|
|
|
|
@overload
|
|
def typeannotation(
|
|
cls: None, *, eq: bool = True, order: bool = False
|
|
) -> Callable[[Type[T]], Type[T]]: ...
|
|
|
|
|
|
@dataclass_transform(eq_default=True, order_default=False)
|
|
def typeannotation(
|
|
cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
|
|
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
|
|
"""
|
|
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
|
|
|
|
:param cls: The data-class type to transform into a type annotation.
|
|
:param eq: Whether to generate functions to support equality comparison.
|
|
:param order: Whether to generate functions to support ordering.
|
|
:returns: A data-class type, or a wrapper for data-class types.
|
|
"""
|
|
|
|
def wrap(cls: Type[T]) -> Type[T]:
|
|
setattr(cls, "__repr__", _compact_dataclass_repr)
|
|
if not dataclasses.is_dataclass(cls):
|
|
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
|
cls,
|
|
init=True,
|
|
repr=False,
|
|
eq=eq,
|
|
order=order,
|
|
unsafe_hash=False,
|
|
frozen=True,
|
|
)
|
|
return cls
|
|
|
|
# see if decorator is used as @typeannotation or @typeannotation()
|
|
if cls is None:
|
|
# called with parentheses
|
|
return wrap
|
|
else:
|
|
# called without parentheses
|
|
return wrap(cls)
|
|
|
|
|
|
@typeannotation
|
|
class Alias:
|
|
"Alternative name of a property, typically used in JSON serialization."
|
|
|
|
name: str
|
|
|
|
|
|
@typeannotation
|
|
class Signed:
|
|
"Signedness of an integer type."
|
|
|
|
is_signed: bool
|
|
|
|
|
|
@typeannotation
|
|
class Storage:
|
|
"Number of bytes the binary representation of an integer type takes, e.g. 4 bytes for an int32."
|
|
|
|
bytes: int
|
|
|
|
|
|
@typeannotation
|
|
class IntegerRange:
|
|
"Minimum and maximum value of an integer. The range is inclusive."
|
|
|
|
minimum: int
|
|
maximum: int
|
|
|
|
|
|
@typeannotation
|
|
class Precision:
|
|
"Precision of a floating-point value."
|
|
|
|
significant_digits: int
|
|
decimal_digits: int = 0
|
|
|
|
@property
|
|
def integer_digits(self) -> int:
|
|
return self.significant_digits - self.decimal_digits
|
|
|
|
|
|
@typeannotation
|
|
class TimePrecision:
|
|
"""
|
|
Precision of a timestamp or time interval.
|
|
|
|
:param decimal_digits: Number of fractional digits retained in the sub-seconds field for a timestamp.
|
|
"""
|
|
|
|
decimal_digits: int = 0
|
|
|
|
|
|
@typeannotation
|
|
class Length:
|
|
"Exact length of a string."
|
|
|
|
value: int
|
|
|
|
|
|
@typeannotation
|
|
class MinLength:
|
|
"Minimum length of a string."
|
|
|
|
value: int
|
|
|
|
|
|
@typeannotation
|
|
class MaxLength:
|
|
"Maximum length of a string."
|
|
|
|
value: int
|
|
|
|
|
|
@typeannotation
|
|
class SpecialConversion:
|
|
"Indicates that the annotated type is subject to custom conversion rules."
|
|
|
|
|
|
int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
|
|
int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
|
|
int32: TypeAlias = Annotated[
|
|
int,
|
|
Signed(True),
|
|
Storage(4),
|
|
IntegerRange(-2147483648, 2147483647),
|
|
]
|
|
int64: TypeAlias = Annotated[
|
|
int,
|
|
Signed(True),
|
|
Storage(8),
|
|
IntegerRange(-9223372036854775808, 9223372036854775807),
|
|
]
|
|
|
|
uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
|
|
uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
|
|
uint32: TypeAlias = Annotated[
|
|
int,
|
|
Signed(False),
|
|
Storage(4),
|
|
IntegerRange(0, 4294967295),
|
|
]
|
|
uint64: TypeAlias = Annotated[
|
|
int,
|
|
Signed(False),
|
|
Storage(8),
|
|
IntegerRange(0, 18446744073709551615),
|
|
]
|
|
|
|
float32: TypeAlias = Annotated[float, Storage(4)]
|
|
float64: TypeAlias = Annotated[float, Storage(8)]
|
|
|
|
# maps globals of type Annotated[T, ...] defined in this module to their string names
|
|
_auxiliary_types: Dict[object, str] = {}
|
|
module = sys.modules[__name__]
|
|
for var in dir(module):
|
|
typ = getattr(module, var)
|
|
if getattr(typ, "__metadata__", None) is not None:
|
|
# type is Annotated[T, ...]
|
|
_auxiliary_types[typ] = var
|
|
|
|
|
|
def get_auxiliary_format(data_type: object) -> Optional[str]:
|
|
"Returns the JSON format string corresponding to an auxiliary type."
|
|
|
|
return _auxiliary_types.get(data_type)
|