feat: use XDG directory standards

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-03 18:48:53 +02:00
parent 9736f096f6
commit 407c3e3bad
50 changed files with 5611 additions and 508 deletions

View file

@ -7,6 +7,7 @@
import argparse
from .download import Download
from .migrate_xdg import MigrateXDG
from .model import ModelParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
@ -34,6 +35,7 @@ class LlamaCLIParser:
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
MigrateXDG.create(subparsers)
print_subcommand_description(self.parser, subparsers)
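For illustration, a minimal sketch of how the newly registered subcommand is dispatched through argparse; the module path llama_stack.cli.migrate_xdg is an assumption inferred from the relative import above.

import argparse
from llama_stack.cli.migrate_xdg import MigrateXDG  # assumed module path

parser = argparse.ArgumentParser(prog="llama")
subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
MigrateXDG.create(subparsers)  # registers the "migrate-xdg" parser and its func default
args = parser.parse_args(["migrate-xdg", "--dry-run"])
args.func(args)  # dispatches to MigrateXDG._run_migrate_xdg_cmd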

View file

@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path
from llama_stack.distribution.utils.xdg_utils import (
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
)
from .subcommand import Subcommand
class MigrateXDG(Subcommand):
"""CLI command for migrating from legacy ~/.llama to XDG-compliant directories."""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"migrate-xdg",
prog="llama migrate-xdg",
description="Migrate from legacy ~/.llama to XDG-compliant directories",
formatter_class=argparse.RawTextHelpFormatter,
)
self.parser.add_argument(
"--dry-run", action="store_true", help="Show what would be done without actually moving files"
)
self.parser.set_defaults(func=self._run_migrate_xdg_cmd)
@staticmethod
def create(subparsers: argparse._SubParsersAction):
return MigrateXDG(subparsers)
def _run_migrate_xdg_cmd(self, args: argparse.Namespace) -> None:
"""Run the migrate-xdg command."""
if not migrate_to_xdg(dry_run=args.dry_run):
sys.exit(1)
def migrate_to_xdg(dry_run: bool = False) -> bool:
"""
Migrate from legacy ~/.llama to XDG-compliant directories.
Args:
dry_run: If True, only show what would be done without actually moving files
Returns:
bool: True if migration was successful or not needed, False otherwise
"""
legacy_path = Path.home() / ".llama"
if not legacy_path.exists():
print("No legacy ~/.llama directory found. Nothing to migrate.")
return True
# Resolve the current directories to see whether they still point at the legacy path
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
state_dir = get_llama_stack_state_dir()
if str(config_dir) == str(legacy_path):
print("Already using legacy directory. No migration needed.")
return True
print(f"Found legacy directory at: {legacy_path}")
print("Will migrate to XDG-compliant directories:")
print(f" Config: {config_dir}")
print(f" Data: {data_dir}")
print(f" State: {state_dir}")
print()
# Define migration mapping
migrations = [
# (source_subdir, target_base_dir, description)
("distributions", config_dir, "Distribution configurations"),
("providers.d", config_dir, "External provider configurations"),
("checkpoints", data_dir, "Model checkpoints"),
("runtime", state_dir, "Runtime state files"),
]
# Check what needs to be migrated
items_to_migrate = []
for subdir, target_base, description in migrations:
source_path = legacy_path / subdir
if source_path.exists():
target_path = target_base / subdir
items_to_migrate.append((source_path, target_path, description))
if not items_to_migrate:
print("No items found to migrate.")
return True
print("Items to migrate:")
for source_path, target_path, description in items_to_migrate:
print(f" {description}: {source_path} -> {target_path}")
if dry_run:
print("\nDry run mode: No files will be moved.")
return True
# Ask for confirmation
response = input("\nDo you want to proceed with the migration? (y/N): ")
if response.lower() not in ["y", "yes"]:
print("Migration cancelled.")
return False
# Perform the migration
print("\nMigrating files...")
for source_path, target_path, description in items_to_migrate:
try:
# Create target directory if it doesn't exist
target_path.parent.mkdir(parents=True, exist_ok=True)
# Check if target already exists
if target_path.exists():
print(f" Warning: Target already exists: {target_path}")
print(f" Skipping {description}")
continue
# Move the directory
shutil.move(str(source_path), str(target_path))
print(f" Moved {description}: {source_path} -> {target_path}")
except Exception as e:
print(f" Error migrating {description}: {e}")
return False
# Check if legacy directory is now empty (except for hidden files)
remaining_items = [item for item in legacy_path.iterdir() if not item.name.startswith(".")]
if not remaining_items:
print(f"\nMigration complete! Legacy directory {legacy_path} is now empty.")
response = input("Remove empty legacy directory? (y/N): ")
if response.lower() in ["y", "yes"]:
try:
shutil.rmtree(legacy_path)
print(f"Removed empty legacy directory: {legacy_path}")
except Exception as e:
print(f"Could not remove legacy directory: {e}")
else:
print(f"\nMigration complete! Some items remain in legacy directory: {remaining_items}")
print("\nMigration successful!")
print("You may need to update any custom scripts or configurations that reference the old paths.")
return True
def main():
parser = argparse.ArgumentParser(description="Migrate from legacy ~/.llama to XDG-compliant directories")
parser.add_argument("--dry-run", action="store_true", help="Show what would be done without actually moving files")
args = parser.parse_args()
if not migrate_to_xdg(dry_run=args.dry_run):
sys.exit(1)
if __name__ == "__main__":
main()
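A rough usage sketch of the function above, useful for previewing a migration before committing to it; the import path is assumed, not confirmed by this diff.

from llama_stack.cli.migrate_xdg import migrate_to_xdg  # assumed module path

# Dry run prints the planned moves without touching the filesystem.
if migrate_to_xdg(dry_run=True):
    # The real run prompts for confirmation before moving each directory group.
    migrate_to_xdg(dry_run=False)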

View file

@ -7,12 +7,35 @@
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
from .xdg_utils import (
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
)
# Base directory for all llama-stack configuration
# This now uses XDG-compliant paths with backwards compatibility
LLAMA_STACK_CONFIG_DIR = get_llama_stack_config_dir()
# Distribution configurations - stored in config directory
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
# Model checkpoints - stored in data directory (persistent data)
DEFAULT_CHECKPOINT_DIR = get_llama_stack_data_dir() / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
# Runtime data - stored in state directory
RUNTIME_BASE_DIR = get_llama_stack_state_dir() / "runtime"
# External providers - stored in config directory
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
# Legacy compatibility: if the legacy environment variable is set, use it for all paths
# This ensures that existing installations continue to work
legacy_config_dir = os.getenv("LLAMA_STACK_CONFIG_DIR")
if legacy_config_dir:
legacy_base = Path(legacy_config_dir)
LLAMA_STACK_CONFIG_DIR = legacy_base
DISTRIBS_BASE_DIR = legacy_base / "distributions"
DEFAULT_CHECKPOINT_DIR = legacy_base / "checkpoints"
RUNTIME_BASE_DIR = legacy_base / "runtime"
EXTERNAL_PROVIDERS_DIR = legacy_base / "providers.d"
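A simplified sketch of the resulting precedence: the legacy LLAMA_STACK_CONFIG_DIR override wins for every derived path, otherwise the XDG helpers decide. It deliberately omits the populated-~/.llama fallback that get_llama_stack_config_dir handles internally.

import os
from pathlib import Path

def resolve_distribs_dir() -> Path:
    # Legacy override applies to all derived paths, matching the block above.
    legacy = os.getenv("LLAMA_STACK_CONFIG_DIR")
    if legacy:
        return Path(legacy) / "distributions"
    # Otherwise fall back to $XDG_CONFIG_HOME (default ~/.config) / llama-stack.
    xdg_config = Path(os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")))
    return xdg_config / "llama-stack" / "distributions"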

View file

@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path
def get_xdg_config_home() -> Path:
"""
Get the XDG config home directory.
Returns:
Path: XDG_CONFIG_HOME if set, otherwise ~/.config
"""
return Path(os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")))
def get_xdg_data_home() -> Path:
"""
Get the XDG data home directory.
Returns:
Path: XDG_DATA_HOME if set, otherwise ~/.local/share
"""
return Path(os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share")))
def get_xdg_cache_home() -> Path:
"""
Get the XDG cache home directory.
Returns:
Path: XDG_CACHE_HOME if set, otherwise ~/.cache
"""
return Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")))
def get_xdg_state_home() -> Path:
"""
Get the XDG state home directory.
Returns:
Path: XDG_STATE_HOME if set, otherwise ~/.local/state
"""
return Path(os.environ.get("XDG_STATE_HOME", os.path.expanduser("~/.local/state")))
def get_llama_stack_config_dir() -> Path:
"""
Get the llama-stack configuration directory.
This function provides backwards compatibility by checking for the legacy
LLAMA_STACK_CONFIG_DIR environment variable first, then falling back to
XDG-compliant paths.
Returns:
Path: Configuration directory for llama-stack
"""
# Check for legacy environment variable first for backwards compatibility
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
if legacy_dir:
return Path(legacy_dir)
# Check if legacy ~/.llama directory exists and contains data
legacy_path = Path.home() / ".llama"
if legacy_path.exists() and any(legacy_path.iterdir()):
return legacy_path
# Use XDG-compliant path
return get_xdg_config_home() / "llama-stack"
def get_llama_stack_data_dir() -> Path:
"""
Get the llama-stack data directory.
This is used for persistent data like model checkpoints.
Returns:
Path: Data directory for llama-stack
"""
# Check for legacy environment variable first for backwards compatibility
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
if legacy_dir:
return Path(legacy_dir)
# Check if legacy ~/.llama directory exists and contains data
legacy_path = Path.home() / ".llama"
if legacy_path.exists() and any(legacy_path.iterdir()):
return legacy_path
# Use XDG-compliant path
return get_xdg_data_home() / "llama-stack"
def get_llama_stack_cache_dir() -> Path:
"""
Get the llama-stack cache directory.
This is used for temporary/cache data.
Returns:
Path: Cache directory for llama-stack
"""
# Check for legacy environment variable first for backwards compatibility
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
if legacy_dir:
return Path(legacy_dir)
# Check if legacy ~/.llama directory exists and contains data
legacy_path = Path.home() / ".llama"
if legacy_path.exists() and any(legacy_path.iterdir()):
return legacy_path
# Use XDG-compliant path
return get_xdg_cache_home() / "llama-stack"
def get_llama_stack_state_dir() -> Path:
"""
Get the llama-stack state directory.
This is used for runtime state data.
Returns:
Path: State directory for llama-stack
"""
# Check for legacy environment variable first for backwards compatibility
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
if legacy_dir:
return Path(legacy_dir)
# Check if legacy ~/.llama directory exists and contains data
legacy_path = Path.home() / ".llama"
if legacy_path.exists() and any(legacy_path.iterdir()):
return legacy_path
# Use XDG-compliant path
return get_xdg_state_home() / "llama-stack"
def get_xdg_compliant_path(path_type: str, subdirectory: str | None = None, legacy_fallback: bool = True) -> Path:
"""
Get an XDG-compliant path for a given type.
Args:
path_type: Type of path ('config', 'data', 'cache', 'state')
subdirectory: Optional subdirectory within the base path
legacy_fallback: Whether to check for legacy ~/.llama directory
Returns:
Path: XDG-compliant path
Raises:
ValueError: If path_type is not recognized
"""
path_map = {
"config": get_llama_stack_config_dir,
"data": get_llama_stack_data_dir,
"cache": get_llama_stack_cache_dir,
"state": get_llama_stack_state_dir,
}
if path_type not in path_map:
raise ValueError(f"Unknown path type: {path_type}. Must be one of: {list(path_map.keys())}")
base_path = path_map[path_type]()
if subdirectory:
return base_path / subdirectory
return base_path
def migrate_legacy_directory() -> bool:
"""
Migrate from legacy ~/.llama directory to XDG-compliant directories.
This function helps users migrate their existing data to the new
XDG-compliant structure.
Returns:
bool: True if migration was successful or not needed, False otherwise
"""
legacy_path = Path.home() / ".llama"
if not legacy_path.exists():
return True # No migration needed
print(f"Found legacy directory at {legacy_path}")
print("Consider migrating to XDG-compliant directories:")
print(f" Config: {get_llama_stack_config_dir()}")
print(f" Data: {get_llama_stack_data_dir()}")
print(f" Cache: {get_llama_stack_cache_dir()}")
print(f" State: {get_llama_stack_state_dir()}")
print("Migration can be done by moving the appropriate subdirectories.")
return True
def ensure_directory_exists(path: Path) -> None:
"""
Ensure a directory exists, creating it if necessary.
Args:
path: Path to the directory
"""
path.mkdir(parents=True, exist_ok=True)
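A quick sketch of the lookup order these helpers implement; the environment manipulation below is purely illustrative.

import os
from pathlib import Path
from llama_stack.distribution.utils.xdg_utils import get_llama_stack_config_dir

# 1. The legacy environment variable takes precedence over everything.
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/opt/llama"
assert get_llama_stack_config_dir() == Path("/opt/llama")

# 2. Without it, a populated ~/.llama wins; otherwise $XDG_CONFIG_HOME/llama-stack
#    (defaulting to ~/.config/llama-stack) is used.
del os.environ["LLAMA_STACK_CONFIG_DIR"]
os.environ["XDG_CONFIG_HOME"] = "/tmp/xdg"
print(get_llama_stack_config_dir())  # /tmp/xdg/llama-stack, unless ~/.llama holds data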

View file

@ -4,6 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64

View file

@ -12,23 +12,12 @@ Type-safe data interchange for Python data classes.
import dataclasses
import sys
from collections.abc import Callable
from dataclasses import is_dataclass
from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload
if sys.version_info >= (3, 9):
from typing import Annotated as Annotated
else:
from typing_extensions import Annotated as Annotated
if sys.version_info >= (3, 10):
from typing import TypeAlias as TypeAlias
else:
from typing_extensions import TypeAlias as TypeAlias
if sys.version_info >= (3, 11):
from typing import dataclass_transform as dataclass_transform
else:
from typing_extensions import dataclass_transform as dataclass_transform
from typing import Annotated as Annotated
from typing import TypeAlias as TypeAlias
from typing import TypeVar, overload
from typing import dataclass_transform as dataclass_transform
T = TypeVar("T")
@ -56,17 +45,17 @@ class CompactDataClass:
@overload
def typeannotation(cls: Type[T], /) -> Type[T]: ...
def typeannotation(cls: type[T], /) -> type[T]: ...
@overload
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ...
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[type[T]], type[T]]: ...
@dataclass_transform(eq_default=True, order_default=False)
def typeannotation(
cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
cls: type[T] | None = None, *, eq: bool = True, order: bool = False
) -> type[T] | Callable[[type[T]], type[T]]:
"""
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
@ -76,7 +65,7 @@ def typeannotation(
:returns: A data-class type, or a wrapper for data-class types.
"""
def wrap(cls: Type[T]) -> Type[T]:
def wrap(cls: type[T]) -> type[T]:
# mypy fails to equate bound-y functions (first argument interpreted as
# the bound object) with class methods, hence the `ignore` directive.
cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign]
@ -179,41 +168,41 @@ class SpecialConversion:
"Indicates that the annotated type is subject to custom conversion rules."
int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
int32: TypeAlias = Annotated[
type int8 = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
type int16 = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
type int32 = Annotated[
int,
Signed(True),
Storage(4),
IntegerRange(-2147483648, 2147483647),
]
int64: TypeAlias = Annotated[
type int64 = Annotated[
int,
Signed(True),
Storage(8),
IntegerRange(-9223372036854775808, 9223372036854775807),
]
uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
uint32: TypeAlias = Annotated[
type uint8 = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
type uint16 = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
type uint32 = Annotated[
int,
Signed(False),
Storage(4),
IntegerRange(0, 4294967295),
]
uint64: TypeAlias = Annotated[
type uint64 = Annotated[
int,
Signed(False),
Storage(8),
IntegerRange(0, 18446744073709551615),
]
float32: TypeAlias = Annotated[float, Storage(4)]
float64: TypeAlias = Annotated[float, Storage(8)]
type float32 = Annotated[float, Storage(4)]
type float64 = Annotated[float, Storage(8)]
# maps globals of type Annotated[T, ...] defined in this module to their string names
_auxiliary_types: Dict[object, str] = {}
_auxiliary_types: dict[object, str] = {}
module = sys.modules[__name__]
for var in dir(module):
typ = getattr(module, var)
@ -222,7 +211,7 @@ for var in dir(module):
_auxiliary_types[typ] = var
def get_auxiliary_format(data_type: object) -> Optional[str]:
def get_auxiliary_format(data_type: object) -> str | None:
"Returns the JSON format string corresponding to an auxiliary type."
return _auxiliary_types.get(data_type)
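A hedged aside on the syntax change above: the `type` statement (PEP 695) requires Python 3.12 or newer and produces a TypeAliasType whose value is evaluated lazily. A minimal standalone illustration, with made-up metadata:

from typing import Annotated

type uint8_demo = Annotated[int, "unsigned", "1 byte"]  # illustrative metadata only
print(uint8_demo.__value__)  # typing.Annotated[int, 'unsigned', '1 byte']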

View file

@ -12,12 +12,11 @@ import enum
import ipaddress
import math
import re
import sys
import types
import typing
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Literal, TypeVar, Union
from .auxiliary import (
Alias,
@ -40,57 +39,57 @@ T = TypeVar("T")
@dataclass
class JsonSchemaNode:
title: Optional[str]
description: Optional[str]
title: str | None
description: str | None
@dataclass
class JsonSchemaType(JsonSchemaNode):
type: str
format: Optional[str]
format: str | None
@dataclass
class JsonSchemaBoolean(JsonSchemaType):
type: Literal["boolean"]
const: Optional[bool]
default: Optional[bool]
examples: Optional[List[bool]]
const: bool | None
default: bool | None
examples: list[bool] | None
@dataclass
class JsonSchemaInteger(JsonSchemaType):
type: Literal["integer"]
const: Optional[int]
default: Optional[int]
examples: Optional[List[int]]
enum: Optional[List[int]]
minimum: Optional[int]
maximum: Optional[int]
const: int | None
default: int | None
examples: list[int] | None
enum: list[int] | None
minimum: int | None
maximum: int | None
@dataclass
class JsonSchemaNumber(JsonSchemaType):
type: Literal["number"]
const: Optional[float]
default: Optional[float]
examples: Optional[List[float]]
minimum: Optional[float]
maximum: Optional[float]
exclusiveMinimum: Optional[float]
exclusiveMaximum: Optional[float]
multipleOf: Optional[float]
const: float | None
default: float | None
examples: list[float] | None
minimum: float | None
maximum: float | None
exclusiveMinimum: float | None
exclusiveMaximum: float | None
multipleOf: float | None
@dataclass
class JsonSchemaString(JsonSchemaType):
type: Literal["string"]
const: Optional[str]
default: Optional[str]
examples: Optional[List[str]]
enum: Optional[List[str]]
minLength: Optional[int]
maxLength: Optional[int]
const: str | None
default: str | None
examples: list[str] | None
enum: list[str] | None
minLength: int | None
maxLength: int | None
@dataclass
@ -102,9 +101,9 @@ class JsonSchemaArray(JsonSchemaType):
@dataclass
class JsonSchemaObject(JsonSchemaType):
type: Literal["object"]
properties: Optional[Dict[str, "JsonSchemaAny"]]
additionalProperties: Optional[bool]
required: Optional[List[str]]
properties: dict[str, "JsonSchemaAny"] | None
additionalProperties: bool | None
required: list[str] | None
@dataclass
@ -114,24 +113,24 @@ class JsonSchemaRef(JsonSchemaNode):
@dataclass
class JsonSchemaAllOf(JsonSchemaNode):
allOf: List["JsonSchemaAny"]
allOf: list["JsonSchemaAny"]
@dataclass
class JsonSchemaAnyOf(JsonSchemaNode):
anyOf: List["JsonSchemaAny"]
anyOf: list["JsonSchemaAny"]
@dataclass
class Discriminator:
propertyName: str
mapping: Dict[str, str]
mapping: dict[str, str]
@dataclass
class JsonSchemaOneOf(JsonSchemaNode):
oneOf: List["JsonSchemaAny"]
discriminator: Optional[Discriminator]
oneOf: list["JsonSchemaAny"]
discriminator: Discriminator | None
JsonSchemaAny = Union[
@ -149,10 +148,10 @@ JsonSchemaAny = Union[
@dataclass
class JsonSchemaTopLevelObject(JsonSchemaObject):
schema: Annotated[str, Alias("$schema")]
definitions: Optional[Dict[str, JsonSchemaAny]]
definitions: dict[str, JsonSchemaAny] | None
def integer_range_to_type(min_value: float, max_value: float) -> type:
def integer_range_to_type(min_value: float, max_value: float) -> Any:
if min_value >= -(2**15) and max_value < 2**15:
return int16
elif min_value >= -(2**31) and max_value < 2**31:
@ -173,11 +172,11 @@ def enum_safe_name(name: str) -> str:
def enum_values_to_type(
module: types.ModuleType,
name: str,
values: Dict[str, Any],
title: Optional[str] = None,
description: Optional[str] = None,
) -> Type[enum.Enum]:
enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore
values: dict[str, Any],
title: str | None = None,
description: str | None = None,
) -> type[enum.Enum]:
enum_class: type[enum.Enum] = enum.Enum(name, values) # type: ignore
# assign the newly created type to the same module where the defining class is
enum_class.__module__ = module.__name__
@ -330,7 +329,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
type_def = node_to_typedef(module, context, node.items)
if type_def.default is not dataclasses.MISSING:
raise TypeError("disallowed: `default` for array element type")
list_type = List[(type_def.type,)] # type: ignore
list_type = list[(type_def.type,)] # type: ignore
return TypeDef(list_type, dataclasses.MISSING)
elif isinstance(node, JsonSchemaObject):
@ -344,8 +343,8 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
class_name = context
fields: List[Tuple[str, Any, dataclasses.Field]] = []
params: Dict[str, DocstringParam] = {}
fields: list[tuple[str, Any, dataclasses.Field]] = []
params: dict[str, DocstringParam] = {}
for prop_name, prop_node in node.properties.items():
type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node)
if prop_name in required:
@ -358,10 +357,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
params[prop_name] = DocstringParam(prop_name, prop_desc)
fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING)
if sys.version_info >= (3, 12):
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
else:
class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__})
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
class_type.__doc__ = str(
Docstring(
short_description=node.title,
@ -388,7 +384,7 @@ class SchemaFlatteningOptions:
recursive: bool = False
def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema:
def flatten_schema(schema: Schema, *, options: SchemaFlatteningOptions | None = None) -> Schema:
top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema))
flattener = SchemaFlattener(options)
obj = flattener.flatten(top_node)
@ -398,7 +394,7 @@ def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions]
class SchemaFlattener:
options: SchemaFlatteningOptions
def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None:
def __init__(self, options: SchemaFlatteningOptions | None = None) -> None:
self.options = options or SchemaFlatteningOptions()
def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject:
@ -406,10 +402,10 @@ class SchemaFlattener:
return source_node
source_props = source_node.properties or {}
target_props: Dict[str, JsonSchemaAny] = {}
target_props: dict[str, JsonSchemaAny] = {}
source_reqs = source_node.required or []
target_reqs: List[str] = []
target_reqs: list[str] = []
for name, prop in source_props.items():
if not isinstance(prop, JsonSchemaObject):

View file

@ -10,7 +10,7 @@ Type-safe data interchange for Python data classes.
:see: https://github.com/hunyadi/strong_typing
"""
from typing import Dict, List, Union
from typing import Union
class JsonObject:
@ -28,8 +28,8 @@ JsonType = Union[
int,
float,
str,
Dict[str, "JsonType"],
List["JsonType"],
dict[str, "JsonType"],
list["JsonType"],
]
# a JSON type that cannot contain `null` values
@ -38,9 +38,9 @@ StrictJsonType = Union[
int,
float,
str,
Dict[str, "StrictJsonType"],
List["StrictJsonType"],
dict[str, "StrictJsonType"],
list["StrictJsonType"],
]
# a meta-type that captures the object type in a JSON schema
Schema = Dict[str, JsonType]
Schema = dict[str, JsonType]
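For example, a value satisfying the Schema alias (dict[str, JsonType]) is just a plain dict of JSON-compatible builtins:

example_schema: dict = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}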

View file

@ -20,19 +20,14 @@ import ipaddress
import sys
import typing
import uuid
from collections.abc import Callable
from types import ModuleType
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
@ -70,7 +65,7 @@ V = TypeVar("V")
class Deserializer(abc.ABC, Generic[T]):
"Parses a JSON value into a Python type."
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
"""
Creates auxiliary parsers that this parser is depending on.
@ -203,19 +198,19 @@ class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]):
return ipaddress.IPv6Address(data)
class ListDeserializer(Deserializer[List[T]]):
class ListDeserializer(Deserializer[list[T]]):
"Recursively de-serializes a JSON array into a Python `list`."
item_type: Type[T]
item_type: type[T]
item_parser: Deserializer
def __init__(self, item_type: Type[T]) -> None:
def __init__(self, item_type: type[T]) -> None:
self.item_type = item_type
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.item_parser = _get_deserializer(self.item_type, context)
def parse(self, data: JsonType) -> List[T]:
def parse(self, data: JsonType) -> list[T]:
if not isinstance(data, list):
type_name = python_type_to_str(self.item_type)
raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}")
@ -223,19 +218,19 @@ class ListDeserializer(Deserializer[List[T]]):
return [self.item_parser.parse(item) for item in data]
class DictDeserializer(Deserializer[Dict[K, V]]):
class DictDeserializer(Deserializer[dict[K, V]]):
"Recursively de-serializes a JSON object into a Python `dict`."
key_type: Type[K]
value_type: Type[V]
key_type: type[K]
value_type: type[V]
value_parser: Deserializer[V]
def __init__(self, key_type: Type[K], value_type: Type[V]) -> None:
def __init__(self, key_type: type[K], value_type: type[V]) -> None:
self.key_type = key_type
self.value_type = value_type
self._check_key_type()
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.value_parser = _get_deserializer(self.value_type, context)
def _check_key_type(self) -> None:
@ -264,48 +259,48 @@ class DictDeserializer(Deserializer[Dict[K, V]]):
value_type_name = python_type_to_str(self.value_type)
return f"Dict[{key_type_name}, {value_type_name}]"
def parse(self, data: JsonType) -> Dict[K, V]:
def parse(self, data: JsonType) -> dict[K, V]:
if not isinstance(data, dict):
raise JsonTypeError(
f"`type `{self.container_type}` expects JSON `object` data but instead received: {data}"
)
return dict(
(self.key_type(key), self.value_parser.parse(value)) # type: ignore[call-arg]
return {
self.key_type(key): self.value_parser.parse(value) # type: ignore[call-arg]
for key, value in data.items()
)
}
class SetDeserializer(Deserializer[Set[T]]):
class SetDeserializer(Deserializer[set[T]]):
"Recursively de-serializes a JSON list into a Python `set`."
member_type: Type[T]
member_type: type[T]
member_parser: Deserializer
def __init__(self, member_type: Type[T]) -> None:
def __init__(self, member_type: type[T]) -> None:
self.member_type = member_type
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parser = _get_deserializer(self.member_type, context)
def parse(self, data: JsonType) -> Set[T]:
def parse(self, data: JsonType) -> set[T]:
if not isinstance(data, list):
type_name = python_type_to_str(self.member_type)
raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}")
return set(self.member_parser.parse(item) for item in data)
return {self.member_parser.parse(item) for item in data}
class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
class TupleDeserializer(Deserializer[tuple[Any, ...]]):
"Recursively de-serializes a JSON list into a Python `tuple`."
item_types: Tuple[Type[Any], ...]
item_parsers: Tuple[Deserializer[Any], ...]
item_types: tuple[type[Any], ...]
item_parsers: tuple[Deserializer[Any], ...]
def __init__(self, item_types: Tuple[Type[Any], ...]) -> None:
def __init__(self, item_types: tuple[type[Any], ...]) -> None:
self.item_types = item_types
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types)
@property
@ -313,7 +308,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types)
return f"Tuple[{type_names}]"
def parse(self, data: JsonType) -> Tuple[Any, ...]:
def parse(self, data: JsonType) -> tuple[Any, ...]:
if not isinstance(data, list) or len(data) != len(self.item_parsers):
if not isinstance(data, list):
raise JsonTypeError(
@ -331,13 +326,13 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
class UnionDeserializer(Deserializer):
"De-serializes a JSON value (of any type) into a Python union type."
member_types: Tuple[type, ...]
member_parsers: Tuple[Deserializer, ...]
member_types: tuple[type, ...]
member_parsers: tuple[Deserializer, ...]
def __init__(self, member_types: Tuple[type, ...]) -> None:
def __init__(self, member_types: tuple[type, ...]) -> None:
self.member_types = member_types
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types)
def parse(self, data: JsonType) -> Any:
@ -354,15 +349,15 @@ class UnionDeserializer(Deserializer):
raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}")
def get_literal_properties(typ: type) -> Set[str]:
def get_literal_properties(typ: type) -> set[str]:
"Returns the names of all properties in a class that are of a literal type."
return set(
return {
property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type)
)
}
def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
def get_discriminating_properties(types: tuple[type, ...]) -> set[str]:
"Returns a set of properties with literal type that are common across all specified classes."
if not types or not all(isinstance(typ, type) for typ in types):
@ -378,15 +373,15 @@ def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
class TaggedUnionDeserializer(Deserializer):
"De-serializes a JSON value with one or more disambiguating properties into a Python union type."
member_types: Tuple[type, ...]
disambiguating_properties: Set[str]
member_parsers: Dict[Tuple[str, Any], Deserializer]
member_types: tuple[type, ...]
disambiguating_properties: set[str]
member_parsers: dict[tuple[str, Any], Deserializer]
def __init__(self, member_types: Tuple[type, ...]) -> None:
def __init__(self, member_types: tuple[type, ...]) -> None:
self.member_types = member_types
self.disambiguating_properties = get_discriminating_properties(member_types)
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parsers = {}
for member_type in self.member_types:
for property_name in self.disambiguating_properties:
@ -435,13 +430,13 @@ class TaggedUnionDeserializer(Deserializer):
class LiteralDeserializer(Deserializer):
"De-serializes a JSON value into a Python literal type."
values: Tuple[Any, ...]
values: tuple[Any, ...]
parser: Deserializer
def __init__(self, values: Tuple[Any, ...]) -> None:
def __init__(self, values: tuple[Any, ...]) -> None:
self.values = values
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
literal_type_tuple = tuple(type(value) for value in self.values)
literal_type_set = set(literal_type_tuple)
if len(literal_type_set) != 1:
@ -464,9 +459,9 @@ class LiteralDeserializer(Deserializer):
class EnumDeserializer(Deserializer[E]):
"Returns an enumeration instance based on the enumeration value read from a JSON value."
enum_type: Type[E]
enum_type: type[E]
def __init__(self, enum_type: Type[E]) -> None:
def __init__(self, enum_type: type[E]) -> None:
self.enum_type = enum_type
def parse(self, data: JsonType) -> E:
@ -504,13 +499,13 @@ class FieldDeserializer(abc.ABC, Generic[T, R]):
self.parser = parser
@abc.abstractmethod
def parse_field(self, data: Dict[str, JsonType]) -> R: ...
def parse_field(self, data: dict[str, JsonType]) -> R: ...
class RequiredFieldDeserializer(FieldDeserializer[T, T]):
"Deserializes a JSON property into a mandatory Python object field."
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
if self.property_name not in data:
raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}")
@ -520,7 +515,7 @@ class RequiredFieldDeserializer(FieldDeserializer[T, T]):
class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]):
"Deserializes a JSON property into an optional Python object field with a default value of `None`."
def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]:
def parse_field(self, data: dict[str, JsonType]) -> T | None:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -543,7 +538,7 @@ class DefaultFieldDeserializer(FieldDeserializer[T, T]):
super().__init__(property_name, field_name, parser)
self.default_value = default_value
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -566,7 +561,7 @@ class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]):
super().__init__(property_name, field_name, parser)
self.default_factory = default_factory
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -578,22 +573,22 @@ class ClassDeserializer(Deserializer[T]):
"Base class for de-serializing class-like types such as data classes, named tuples and regular classes."
class_type: type
property_parsers: List[FieldDeserializer]
property_fields: Set[str]
property_parsers: list[FieldDeserializer]
property_fields: set[str]
def __init__(self, class_type: Type[T]) -> None:
def __init__(self, class_type: type[T]) -> None:
self.class_type = class_type
def assign(self, property_parsers: List[FieldDeserializer]) -> None:
def assign(self, property_parsers: list[FieldDeserializer]) -> None:
self.property_parsers = property_parsers
self.property_fields = set(property_parser.property_name for property_parser in property_parsers)
self.property_fields = {property_parser.property_name for property_parser in property_parsers}
def parse(self, data: JsonType) -> T:
if not isinstance(data, dict):
type_name = python_type_to_str(self.class_type)
raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}")
object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data)
object_data: dict[str, JsonType] = typing.cast(dict[str, JsonType], data)
field_values = {}
for property_parser in self.property_parsers:
@ -619,8 +614,8 @@ class ClassDeserializer(Deserializer[T]):
class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
"De-serializes a named tuple from a JSON `object`."
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = [
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = [
RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context))
for field_name, field_type in get_resolved_hints(self.class_type).items()
]
@ -634,13 +629,13 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
class DataclassDeserializer(ClassDeserializer[T]):
"De-serializes a data class from a JSON `object`."
def __init__(self, class_type: Type[T]) -> None:
def __init__(self, class_type: type[T]) -> None:
if not dataclasses.is_dataclass(class_type):
raise TypeError("expected: data-class type")
super().__init__(class_type) # type: ignore[arg-type]
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = []
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = []
resolved_hints = get_resolved_hints(self.class_type)
for field in dataclasses.fields(self.class_type):
field_type = resolved_hints[field.name]
@ -651,7 +646,7 @@ class DataclassDeserializer(ClassDeserializer[T]):
has_default_factory = field.default_factory is not dataclasses.MISSING
if is_optional:
required_type: Type[T] = unwrap_optional_type(field_type)
required_type: type[T] = unwrap_optional_type(field_type)
else:
required_type = field_type
@ -691,15 +686,15 @@ class FrozenDataclassDeserializer(DataclassDeserializer[T]):
class TypedClassDeserializer(ClassDeserializer[T]):
"De-serializes a class with type annotations from a JSON `object` by iterating over class properties."
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = []
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = []
for field_name, field_type in get_resolved_hints(self.class_type).items():
property_name = python_field_to_json_property(field_name, field_type)
is_optional = is_type_optional(field_type)
if is_optional:
required_type: Type[T] = unwrap_optional_type(field_type)
required_type: type[T] = unwrap_optional_type(field_type)
else:
required_type = field_type
@ -715,7 +710,7 @@ class TypedClassDeserializer(ClassDeserializer[T]):
super().assign(property_parsers)
def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer:
def create_deserializer(typ: TypeLike, context: ModuleType | None = None) -> Deserializer:
"""
Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string.
@ -741,15 +736,15 @@ def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) ->
return _get_deserializer(typ, context)
_CACHE: Dict[Tuple[str, str], Deserializer] = {}
_CACHE: dict[tuple[str, str], Deserializer] = {}
def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer:
def _get_deserializer(typ: TypeLike, context: ModuleType | None) -> Deserializer:
"Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string."
cache_key = None
if isinstance(typ, (str, typing.ForwardRef)):
if isinstance(typ, str | typing.ForwardRef):
if context is None:
raise TypeError(f"missing context for evaluating type: {typ}")
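A small usage sketch of the de-serializer API shown above; the import path assumes the vendored strong_typing package location and is not confirmed by this diff.

from dataclasses import dataclass
from llama_stack.strong_typing.deserializer import create_deserializer  # assumed path

@dataclass
class Point:
    x: int
    y: int

parser = create_deserializer(list[Point])
points = parser.parse([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
print(points)  # [Point(x=1, y=2), Point(x=3, y=4)]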

View file

@ -15,17 +15,12 @@ import collections.abc
import dataclasses
import inspect
import re
import sys
import types
import typing
from collections.abc import Callable
from dataclasses import dataclass
from io import StringIO
from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
from typing import Any, Protocol, TypeGuard, TypeVar
from .inspection import (
DataclassInstance,
@ -110,14 +105,14 @@ class Docstring:
:param returns: The returns declaration extracted from a docstring.
"""
short_description: Optional[str] = None
long_description: Optional[str] = None
params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
returns: Optional[DocstringReturns] = None
raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
short_description: str | None = None
long_description: str | None = None
params: dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
returns: DocstringReturns | None = None
raises: dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
@property
def full_description(self) -> Optional[str]:
def full_description(self) -> str | None:
if self.short_description and self.long_description:
return f"{self.short_description}\n\n{self.long_description}"
elif self.short_description:
@ -158,18 +153,18 @@ class Docstring:
return s
def is_exception(member: object) -> TypeGuard[Type[BaseException]]:
def is_exception(member: object) -> TypeGuard[type[BaseException]]:
return isinstance(member, type) and issubclass(member, BaseException)
def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]:
def get_exceptions(module: types.ModuleType) -> dict[str, type[BaseException]]:
"Returns all exception classes declared in a module."
return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)}
return dict(inspect.getmembers(module, is_exception))
class SupportsDoc(Protocol):
__doc__: Optional[str]
__doc__: str | None
def _maybe_unwrap_async_iterator(t):
@ -213,7 +208,7 @@ def parse_type(typ: SupportsDoc) -> Docstring:
# assign exception types
defining_module = inspect.getmodule(typ)
if defining_module:
context: Dict[str, type] = {}
context: dict[str, type] = {}
context.update(get_exceptions(builtins))
context.update(get_exceptions(defining_module))
for exc_name, exc in docstring.raises.items():
@ -262,8 +257,8 @@ def parse_text(text: str) -> Docstring:
else:
long_description = None
params: Dict[str, DocstringParam] = {}
raises: Dict[str, DocstringRaises] = {}
params: dict[str, DocstringParam] = {}
raises: dict[str, DocstringRaises] = {}
returns = None
for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE):
chunk = match.group(0)
@ -325,7 +320,7 @@ def has_docstring(typ: SupportsDoc) -> bool:
return bool(typ.__doc__)
def get_docstring(typ: SupportsDoc) -> Optional[str]:
def get_docstring(typ: SupportsDoc) -> str | None:
if typ.__doc__ is None:
return None
@ -348,7 +343,7 @@ def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False
check_function_docstring(typ, docstring, strict)
def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
def check_dataclass_docstring(typ: type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
"""
Verifies the doc-string of a data-class type.
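A brief sketch of what parse_type extracts from a reStructuredText-style docstring; the import path is an assumption.

from llama_stack.strong_typing.docstring import parse_type  # assumed path

def scale(value: float, factor: float = 2.0) -> float:
    """
    Multiplies a value by a factor.

    :param value: The number to scale.
    :param factor: The multiplier to apply.
    :returns: The scaled value.
    """
    return value * factor

doc = parse_type(scale)
print(doc.short_description)  # Multiplies a value by a factor.
print(list(doc.params))       # ['value', 'factor']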

View file

@ -22,34 +22,19 @@ import sys
import types
import typing
import uuid
from collections.abc import Callable, Iterable
from typing import (
Annotated,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Protocol,
Set,
Tuple,
Type,
TypeGuard,
TypeVar,
Union,
runtime_checkable,
)
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
S = TypeVar("S")
T = TypeVar("T")
K = TypeVar("K")
@ -80,28 +65,20 @@ def _is_type_like(data_type: object) -> bool:
return False
if sys.version_info >= (3, 9):
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
def is_type_like(
data_type: object,
) -> TypeGuard[TypeLike]:
"""
Checks if the object is a type or type-like object (e.g. generic type).
:param data_type: The object to validate.
:returns: True if the object is a type or type-like object.
"""
def is_type_like(
data_type: object,
) -> TypeGuard[TypeLike]:
"""
Checks if the object is a type or type-like object (e.g. generic type).
return _is_type_like(data_type)
:param data_type: The object to validate.
:returns: True if the object is a type or type-like object.
"""
else:
TypeLike = object
def is_type_like(
data_type: object,
) -> bool:
return _is_type_like(data_type)
return _is_type_like(data_type)
def evaluate_member_type(typ: Any, cls: type) -> Any:
@ -129,20 +106,17 @@ def evaluate_type(typ: Any, module: types.ModuleType) -> Any:
# evaluate data-class field whose type annotation is a string
return eval(typ, module.__dict__, locals())
if isinstance(typ, typing.ForwardRef):
if sys.version_info >= (3, 9):
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
else:
return typ._evaluate(module.__dict__, locals())
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
else:
return typ
@runtime_checkable
class DataclassInstance(Protocol):
__dataclass_fields__: typing.ClassVar[Dict[str, dataclasses.Field]]
__dataclass_fields__: typing.ClassVar[dict[str, dataclasses.Field]]
def is_dataclass_type(typ: Any) -> TypeGuard[Type[DataclassInstance]]:
def is_dataclass_type(typ: Any) -> TypeGuard[type[DataclassInstance]]:
"True if the argument corresponds to a data class type (but not an instance)."
typ = unwrap_annotated_type(typ)
@ -167,14 +141,14 @@ class DataclassField:
self.default = default
def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]:
def dataclass_fields(cls: type[DataclassInstance]) -> Iterable[DataclassField]:
"Generates the fields of a data-class resolving forward references."
for field in dataclasses.fields(cls):
yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default)
def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField:
def dataclass_field_by_name(cls: type[DataclassInstance], name: str) -> DataclassField:
"Looks up a field in a data-class by its field name."
for field in dataclasses.fields(cls):
@ -190,7 +164,7 @@ def is_named_tuple_instance(obj: Any) -> TypeGuard[NamedTuple]:
return is_named_tuple_type(type(obj))
def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
def is_named_tuple_type(typ: Any) -> TypeGuard[type[NamedTuple]]:
"""
True if the argument corresponds to a named tuple type.
@ -217,26 +191,14 @@ def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
return all(isinstance(n, str) for n in f)
if sys.version_info >= (3, 11):
def is_type_enum(typ: object) -> TypeGuard[type[enum.Enum]]:
"True if the specified type is an enumeration type."
def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]:
"True if the specified type is an enumeration type."
typ = unwrap_annotated_type(typ)
return isinstance(typ, enum.EnumType)
else:
def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]:
"True if the specified type is an enumeration type."
typ = unwrap_annotated_type(typ)
# use an explicit isinstance(..., type) check to filter out special forms like generics
return isinstance(typ, type) and issubclass(typ, enum.Enum)
typ = unwrap_annotated_type(typ)
return isinstance(typ, enum.EnumType)
def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
def enum_value_types(enum_type: type[enum.Enum]) -> list[type]:
"""
Returns all unique value types of the `enum.Enum` type in definition order.
"""
@ -246,8 +208,8 @@ def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
def extend_enum(
source: Type[enum.Enum],
) -> Callable[[Type[enum.Enum]], Type[enum.Enum]]:
source: type[enum.Enum],
) -> Callable[[type[enum.Enum]], type[enum.Enum]]:
"""
Creates a new enumeration type extending the set of values in an existing type.
@ -255,13 +217,13 @@ def extend_enum(
:returns: A new enumeration type with the extended set of values.
"""
def wrap(extend: Type[enum.Enum]) -> Type[enum.Enum]:
def wrap(extend: type[enum.Enum]) -> type[enum.Enum]:
# create new enumeration type combining the values from both types
values: Dict[str, Any] = {}
values: dict[str, Any] = {}
values.update((e.name, e.value) for e in source)
values.update((e.name, e.value) for e in extend)
# mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
enum_class: type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
# assign the newly created type to the same module where the extending class is defined
enum_class.__module__ = extend.__module__
@ -273,22 +235,13 @@ def extend_enum(
return wrap
if sys.version_info >= (3, 10):
def _is_union_like(typ: object) -> bool:
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
def _is_union_like(typ: object) -> bool:
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
else:
def _is_union_like(typ: object) -> bool:
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
return typing.get_origin(typ) is Union
return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]:
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[type[Any | None]]:
"""
True if the type annotation corresponds to an optional type (e.g. `Optional[T]` or `Union[T1,T2,None]`).
@ -309,7 +262,7 @@ def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Option
return False
def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
def unwrap_optional_type(typ: type[T | None]) -> type[T]:
"""
Extracts the inner type of an optional type.
@ -320,7 +273,7 @@ def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_optional_type, typ)
def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
def _unwrap_optional_type(typ: type[T | None]) -> type[T]:
"Extracts the type qualified as optional (e.g. returns `T` for `Optional[T]`)."
# Optional[T] is represented internally as Union[T, None]
@ -342,7 +295,7 @@ def is_type_union(typ: object) -> bool:
return False
def unwrap_union_types(typ: object) -> Tuple[object, ...]:
def unwrap_union_types(typ: object) -> tuple[object, ...]:
"""
Extracts the inner types of a union type.
@ -354,7 +307,7 @@ def unwrap_union_types(typ: object) -> Tuple[object, ...]:
return _unwrap_union_types(typ)
def _unwrap_union_types(typ: object) -> Tuple[object, ...]:
def _unwrap_union_types(typ: object) -> tuple[object, ...]:
"Extracts the types in a union (e.g. returns a tuple of types `T1` and `T2` for `Union[T1, T2]`)."
if not _is_union_like(typ):
@ -385,7 +338,7 @@ def unwrap_literal_value(typ: object) -> Any:
return args[0]
def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
def unwrap_literal_values(typ: object) -> tuple[Any, ...]:
"""
Extracts the constant values captured by a literal type.
@ -397,7 +350,7 @@ def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
return typing.get_args(typ)
def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
def unwrap_literal_types(typ: object) -> tuple[type, ...]:
"""
Extracts the types of the constant values captured by a literal type.
@ -408,14 +361,14 @@ def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
return tuple(type(t) for t in unwrap_literal_values(typ))
def is_generic_list(typ: object) -> TypeGuard[Type[list]]:
def is_generic_list(typ: object) -> TypeGuard[type[list]]:
"True if the specified type is a generic list, i.e. `List[T]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is list
def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
def unwrap_generic_list(typ: type[list[T]]) -> type[T]:
"""
Extracts the item type of a list type.
@ -426,21 +379,21 @@ def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_generic_list, typ)
def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
"Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
(list_type,) = typing.get_args(typ) # unpack single tuple element
return list_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
"True if the specified type is a generic set, i.e. `Set[T]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is set
def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
def unwrap_generic_set(typ: type[set[T]]) -> type[T]:
"""
Extracts the item type of a set type.
@ -451,21 +404,21 @@ def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_generic_set, typ)
def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
def _unwrap_generic_set(typ: type[set[T]]) -> type[T]:
"Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
(set_type,) = typing.get_args(typ) # unpack single tuple element
return set_type # type: ignore[no-any-return]
def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
def is_generic_dict(typ: object) -> TypeGuard[type[dict]]:
"True if the specified type is a generic dictionary, i.e. `Dict[KeyType, ValueType]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is dict
def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
def unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
"""
Extracts the key and value types of a dictionary type as a tuple.
@ -476,7 +429,7 @@ def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
return _unwrap_generic_dict(unwrap_annotated_type(typ))
def _unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
def _unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
"Extracts the key and value types of a dict type (e.g. returns (`K`, `V`) for `Dict[K, V]`)."
key_type, value_type = typing.get_args(typ)
@ -489,7 +442,7 @@ def is_type_annotated(typ: TypeLike) -> bool:
return getattr(typ, "__metadata__", None) is not None
def get_annotation(data_type: TypeLike, annotation_type: Type[T]) -> Optional[T]:
def get_annotation(data_type: TypeLike, annotation_type: type[T]) -> T | None:
"""
Returns the first annotation on a data type that matches the expected annotation type.
@ -518,7 +471,7 @@ def unwrap_annotated_type(typ: T) -> T:
return typ
def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]:
def rewrap_annotated_type(transform: Callable[[type[S]], type[T]], typ: type[S]) -> type[T]:
"""
Un-boxes, transforms and re-boxes an optionally annotated type.
@ -542,7 +495,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
return transformed_type
def get_module_classes(module: types.ModuleType) -> List[type]:
def get_module_classes(module: types.ModuleType) -> list[type]:
"Returns all classes declared directly in a module."
def is_class_member(member: object) -> TypeGuard[type]:
@ -551,18 +504,11 @@ def get_module_classes(module: types.ModuleType) -> List[type]:
return [class_type for _, class_type in inspect.getmembers(module, is_class_member)]
if sys.version_info >= (3, 9):
def get_resolved_hints(typ: type) -> Dict[str, type]:
return typing.get_type_hints(typ, include_extras=True)
else:
def get_resolved_hints(typ: type) -> Dict[str, type]:
return typing.get_type_hints(typ)
def get_resolved_hints(typ: type) -> dict[str, type]:
return typing.get_type_hints(typ, include_extras=True)
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
def get_class_properties(typ: type) -> Iterable[tuple[str, type | str]]:
"Returns all properties of a class."
if is_dataclass_type(typ):
@ -572,7 +518,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
return resolved_hints.items()
def get_class_property(typ: type, name: str) -> Optional[type | str]:
def get_class_property(typ: type, name: str) -> type | str | None:
"Looks up the annotated type of a property in a class by its property name."
for property_name, property_type in get_class_properties(typ):
@ -586,7 +532,7 @@ class _ROOT:
pass
def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]:
def get_referenced_types(typ: TypeLike, module: types.ModuleType | None = None) -> set[type]:
"""
Extracts types directly or indirectly referenced by this type.
@ -610,10 +556,10 @@ class TypeCollector:
:param graph: The type dependency graph, linking types to types they depend on.
"""
graph: Dict[type, Set[type]]
graph: dict[type, set[type]]
@property
def references(self) -> Set[type]:
def references(self) -> set[type]:
"Types collected by the type collector."
dependencies = set()
@ -638,8 +584,8 @@ class TypeCollector:
def run(
self,
typ: TypeLike,
cls: Type[DataclassInstance],
module: Optional[types.ModuleType],
cls: type[DataclassInstance],
module: types.ModuleType | None,
) -> None:
"""
Extracts types indirectly referenced by this type.
@ -702,26 +648,17 @@ class TypeCollector:
for field in dataclass_fields(typ):
self.run(field.type, typ, context)
else:
for field_name, field_type in get_resolved_hints(typ).items():
for _field_name, field_type in get_resolved_hints(typ).items():
self.run(field_type, typ, context)
return
raise TypeError(f"expected: type-like; got: {typ}")
if sys.version_info >= (3, 10):
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
"Extracts the signature of a function."
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
"Extracts the signature of a function."
return inspect.signature(fn, eval_str=True)
else:
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
"Extracts the signature of a function."
return inspect.signature(fn)
return inspect.signature(fn, eval_str=True)
def is_reserved_property(name: str) -> bool:
@ -756,51 +693,20 @@ def create_module(name: str) -> types.ModuleType:
return module
if sys.version_info >= (3, 10):
def create_data_type(class_name: str, fields: list[tuple[str, type]]) -> type:
"""
Creates a new data-class type dynamically.
def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
"""
Creates a new data-class type dynamically.
:param class_name: The name of new data-class type.
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
:returns: The newly created data-class type.
"""
:param class_name: The name of new data-class type.
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
:returns: The newly created data-class type.
"""
# has the `slots` parameter
return dataclasses.make_dataclass(class_name, fields, slots=True)
else:
def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
"""
Creates a new data-class type dynamically.
:param class_name: The name of new data-class type.
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
:returns: The newly created data-class type.
"""
cls = dataclasses.make_dataclass(class_name, fields)
cls_dict = dict(cls.__dict__)
field_names = tuple(field.name for field in dataclasses.fields(cls))
cls_dict["__slots__"] = field_names
for field_name in field_names:
cls_dict.pop(field_name, None)
cls_dict.pop("__dict__", None)
qualname = getattr(cls, "__qualname__", None)
cls = type(cls)(cls.__name__, (), cls_dict)
if qualname is not None:
cls.__qualname__ = qualname
return cls
# has the `slots` parameter
return dataclasses.make_dataclass(class_name, fields, slots=True)
def create_object(typ: Type[T]) -> T:
def create_object(typ: type[T]) -> T:
"Creates an instance of a type."
if issubclass(typ, Exception):
@ -811,11 +717,7 @@ def create_object(typ: Type[T]) -> T:
return object.__new__(typ)
if sys.version_info >= (3, 9):
TypeOrGeneric = Union[type, types.GenericAlias]
else:
TypeOrGeneric = object
TypeOrGeneric = Union[type, types.GenericAlias]
def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
@ -885,7 +787,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
class RecursiveChecker:
_pred: Optional[Callable[[type, Any], bool]]
_pred: Callable[[type, Any], bool] | None
def __init__(self, pred: Callable[[type, Any], bool]) -> None:
"""
@ -997,9 +899,9 @@ def check_recursive(
obj: object,
/,
*,
pred: Optional[Callable[[type, Any], bool]] = None,
type_pred: Optional[Callable[[type], bool]] = None,
value_pred: Optional[Callable[[Any], bool]] = None,
pred: Callable[[type, Any], bool] | None = None,
type_pred: Callable[[type], bool] | None = None,
value_pred: Callable[[Any], bool] | None = None,
) -> bool:
"""
Checks if a predicate applies to all nested member properties of an object recursively.
@ -1015,7 +917,7 @@ def check_recursive(
if pred is not None:
raise TypeError("filter predicate not permitted when type and value predicates are present")
type_p: Callable[[Type[T]], bool] = type_pred
type_p: Callable[[type[T]], bool] = type_pred
value_p: Callable[[T], bool] = value_pred
pred = lambda typ, obj: not type_p(typ) or value_p(obj) # noqa: E731

View file

@ -11,13 +11,12 @@ Type-safe data interchange for Python data classes.
"""
import keyword
from typing import Optional
from .auxiliary import Alias
from .inspection import get_annotation
def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str:
def python_field_to_json_property(python_id: str, python_type: object | None = None) -> str:
"""
Map a Python field identifier to a JSON property name.

View file

@ -11,7 +11,7 @@ Type-safe data interchange for Python data classes.
"""
import typing
from typing import Any, Literal, Optional, Tuple, Union
from typing import Any, Literal, Union
from .auxiliary import _auxiliary_types
from .inspection import (
@ -39,7 +39,7 @@ class TypeFormatter:
def __init__(self, use_union_operator: bool = False) -> None:
self.use_union_operator = use_union_operator
def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str:
def union_to_str(self, data_type_args: tuple[TypeLike, ...]) -> str:
if self.use_union_operator:
return " | ".join(self.python_type_to_str(t) for t in data_type_args)
else:
@ -100,7 +100,7 @@ class TypeFormatter:
metadata = getattr(data_type, "__metadata__", None)
if metadata is not None:
# type is Annotated[T, ...]
metatuple: Tuple[Any, ...] = metadata
metatuple: tuple[Any, ...] = metadata
arg = typing.get_args(data_type)[0]
# check for auxiliary types with user-defined annotations
@ -110,7 +110,7 @@ class TypeFormatter:
if arg is not auxiliary_arg:
continue
auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None)
auxiliary_metatuple: tuple[Any, ...] | None = getattr(auxiliary_type, "__metadata__", None)
if auxiliary_metatuple is None:
continue

View file

@ -21,24 +21,19 @@ import json
import types
import typing
import uuid
from collections.abc import Callable
from copy import deepcopy
from typing import (
Annotated,
Any,
Callable,
ClassVar,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import jsonschema
from typing_extensions import Annotated
from . import docstring
from .auxiliary import (
@ -71,7 +66,7 @@ OBJECT_ENUM_EXPANSION_LIMIT = 4
T = TypeVar("T")
def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]:
def get_class_docstrings(data_type: type) -> tuple[str | None, str | None]:
docstr = docstring.parse_type(data_type)
# check if class has a doc-string other than the auto-generated string assigned by @dataclass
@ -82,8 +77,8 @@ def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]
def get_class_property_docstrings(
data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None
) -> Dict[str, str]:
data_type: type, transform_fun: Callable[[type, str, str], str] | None = None
) -> dict[str, str]:
"""
Extracts the documentation strings associated with the properties of a composite type.
@ -120,7 +115,7 @@ def docstring_to_schema(data_type: type) -> Schema:
return schema
def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
def id_from_ref(data_type: typing.ForwardRef | str | type) -> str:
"Extracts the name of a possibly forward-referenced type."
if isinstance(data_type, typing.ForwardRef):
@ -132,7 +127,7 @@ def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
return data_type.__name__
def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]:
def type_from_ref(data_type: typing.ForwardRef | str | type) -> tuple[str, type]:
"Creates a type from a forward reference."
if isinstance(data_type, typing.ForwardRef):
@ -148,16 +143,16 @@ def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str,
@dataclasses.dataclass
class TypeCatalogEntry:
schema: Optional[Schema]
schema: Schema | None
identifier: str
examples: Optional[JsonType] = None
examples: JsonType | None = None
class TypeCatalog:
"Maintains an association of well-known Python types to their JSON schema."
_by_type: Dict[TypeLike, TypeCatalogEntry]
_by_name: Dict[str, TypeCatalogEntry]
_by_type: dict[TypeLike, TypeCatalogEntry]
_by_name: dict[str, TypeCatalogEntry]
def __init__(self) -> None:
self._by_type = {}
@ -174,9 +169,9 @@ class TypeCatalog:
def add(
self,
data_type: TypeLike,
schema: Optional[Schema],
schema: Schema | None,
identifier: str,
examples: Optional[List[JsonType]] = None,
examples: list[JsonType] | None = None,
) -> None:
if isinstance(data_type, typing.ForwardRef):
raise TypeError("forward references cannot be used to register a type")
@ -202,17 +197,17 @@ class SchemaOptions:
definitions_path: str = "#/definitions/"
use_descriptions: bool = True
use_examples: bool = True
property_description_fun: Optional[Callable[[type, str, str], str]] = None
property_description_fun: Callable[[type, str, str], str] | None = None
class JsonSchemaGenerator:
"Creates a JSON schema with user-defined type definitions."
type_catalog: ClassVar[TypeCatalog] = TypeCatalog()
types_used: Dict[str, TypeLike]
types_used: dict[str, TypeLike]
options: SchemaOptions
def __init__(self, options: Optional[SchemaOptions] = None):
def __init__(self, options: SchemaOptions | None = None):
if options is None:
self.options = SchemaOptions()
else:
@ -244,13 +239,13 @@ class JsonSchemaGenerator:
def _(self, arg: MaxLength) -> Schema:
return {"maxLength": arg.value}
def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema:
def _with_metadata(self, type_schema: Schema, metadata: tuple[Any, ...] | None) -> Schema:
if metadata:
for m in metadata:
type_schema.update(self._metadata_to_schema(m))
return type_schema
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]:
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: dict | None = None) -> Schema | None:
"""
Returns the JSON schema associated with a simple, unrestricted type.
@ -314,7 +309,7 @@ class JsonSchemaGenerator:
self,
data_type: TypeLike,
force_expand: bool = False,
json_schema_extra: Optional[dict] = None,
json_schema_extra: dict | None = None,
) -> Schema:
common_info = {}
if json_schema_extra and "deprecated" in json_schema_extra:
@ -325,7 +320,7 @@ class JsonSchemaGenerator:
self,
data_type: TypeLike,
force_expand: bool = False,
json_schema_extra: Optional[dict] = None,
json_schema_extra: dict | None = None,
) -> Schema:
"""
Returns the JSON schema associated with a type.
@ -381,7 +376,7 @@ class JsonSchemaGenerator:
return {"$ref": f"{self.options.definitions_path}{identifier}"}
if is_type_enum(typ):
enum_type: Type[enum.Enum] = typ
enum_type: type[enum.Enum] = typ
value_types = enum_value_types(enum_type)
if len(value_types) != 1:
raise ValueError(
@ -496,8 +491,8 @@ class JsonSchemaGenerator:
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
properties: Dict[str, Schema] = {}
required: List[str] = []
properties: dict[str, Schema] = {}
required: list[str] = []
for property_name, property_type in get_class_properties(typ):
# rename property if an alias name is specified
alias = get_annotation(property_type, Alias)
@ -530,16 +525,7 @@ class JsonSchemaGenerator:
# check if value can be directly represented in JSON
if isinstance(
def_value,
(
bool,
int,
float,
str,
enum.Enum,
datetime.datetime,
datetime.date,
datetime.time,
),
bool | int | float | str | enum.Enum | datetime.datetime | datetime.date | datetime.time,
):
property_def["default"] = object_to_json(def_value)
@ -587,7 +573,7 @@ class JsonSchemaGenerator:
return type_schema
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]:
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> tuple[Schema, dict[str, Schema]]:
"""
Returns the JSON schema associated with a type and any nested types.
@ -604,7 +590,7 @@ class JsonSchemaGenerator:
try:
type_schema = self.type_to_schema(data_type, force_expand=force_expand)
types_defined: Dict[str, Schema] = {}
types_defined: dict[str, Schema] = {}
while len(self.types_used) > len(types_defined):
# make a snapshot copy; original collection is going to be modified
types_undefined = {
@ -635,7 +621,7 @@ class Validator(enum.Enum):
def classdef_to_schema(
data_type: TypeLike,
options: Optional[SchemaOptions] = None,
options: SchemaOptions | None = None,
validator: Validator = Validator.Latest,
) -> Schema:
"""
@ -689,7 +675,7 @@ def print_schema(data_type: type) -> None:
print(json.dumps(s, indent=4))
def get_schema_identifier(data_type: type) -> Optional[str]:
def get_schema_identifier(data_type: type) -> str | None:
if data_type in JsonSchemaGenerator.type_catalog:
return JsonSchemaGenerator.type_catalog.get(data_type).identifier
else:
@ -698,9 +684,9 @@ def get_schema_identifier(data_type: type) -> Optional[str]:
def register_schema(
data_type: T,
schema: Optional[Schema] = None,
name: Optional[str] = None,
examples: Optional[List[JsonType]] = None,
schema: Schema | None = None,
name: str | None = None,
examples: list[JsonType] | None = None,
) -> T:
"""
Associates a type with a JSON schema definition.
@ -721,22 +707,22 @@ def register_schema(
@overload
def json_schema_type(cls: Type[T], /) -> Type[T]: ...
def json_schema_type(cls: type[T], /) -> type[T]: ...
@overload
def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ...
def json_schema_type(cls: None, *, schema: Schema | None = None) -> Callable[[type[T]], type[T]]: ...
def json_schema_type(
cls: Optional[Type[T]] = None,
cls: type[T] | None = None,
*,
schema: Optional[Schema] = None,
examples: Optional[List[JsonType]] = None,
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
schema: Schema | None = None,
examples: list[JsonType] | None = None,
) -> type[T] | Callable[[type[T]], type[T]]:
"""Decorator to add user-defined schema definition to a class."""
def wrap(cls: Type[T]) -> Type[T]:
def wrap(cls: type[T]) -> type[T]:
return register_schema(cls, schema, examples=examples)
# see if decorator is used as @json_schema_type or @json_schema_type()

View file

@ -14,7 +14,7 @@ import inspect
import json
import sys
from types import ModuleType
from typing import Any, Optional, TextIO, TypeVar
from typing import Any, TextIO, TypeVar
from .core import JsonType
from .deserializer import create_deserializer
@ -42,7 +42,7 @@ def object_to_json(obj: Any) -> JsonType:
return generator.generate(obj)
def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object:
def json_to_object(typ: TypeLike, data: JsonType, *, context: ModuleType | None = None) -> object:
"""
Creates an object from a representation that has been de-serialized from JSON.

View file

@ -20,19 +20,13 @@ import ipaddress
import sys
import typing
import uuid
from collections.abc import Callable
from types import FunctionType, MethodType, ModuleType
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
@ -133,7 +127,7 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
class EnumSerializer(Serializer[enum.Enum]):
def generate(self, obj: enum.Enum) -> Union[int, str]:
def generate(self, obj: enum.Enum) -> int | str:
value = obj.value
if isinstance(value, int):
return value
@ -141,12 +135,12 @@ class EnumSerializer(Serializer[enum.Enum]):
class UntypedListSerializer(Serializer[list]):
def generate(self, obj: list) -> List[JsonType]:
def generate(self, obj: list) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class UntypedDictSerializer(Serializer[dict]):
def generate(self, obj: dict) -> Dict[str, JsonType]:
def generate(self, obj: dict) -> dict[str, JsonType]:
if obj and isinstance(next(iter(obj.keys())), enum.Enum):
iterator = ((key.value, object_to_json(value)) for key, value in obj.items())
else:
@ -155,41 +149,41 @@ class UntypedDictSerializer(Serializer[dict]):
class UntypedSetSerializer(Serializer[set]):
def generate(self, obj: set) -> List[JsonType]:
def generate(self, obj: set) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class UntypedTupleSerializer(Serializer[tuple]):
def generate(self, obj: tuple) -> List[JsonType]:
def generate(self, obj: tuple) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class TypedCollectionSerializer(Serializer, Generic[T]):
generator: Serializer[T]
def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, item_type: type[T], context: ModuleType | None) -> None:
self.generator = _get_serializer(item_type, context)
class TypedListSerializer(TypedCollectionSerializer[T]):
def generate(self, obj: List[T]) -> List[JsonType]:
def generate(self, obj: list[T]) -> list[JsonType]:
return [self.generator.generate(item) for item in obj]
class TypedStringDictSerializer(TypedCollectionSerializer[T]):
def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, value_type: type[T], context: ModuleType | None) -> None:
super().__init__(value_type, context)
def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]:
def generate(self, obj: dict[str, T]) -> dict[str, JsonType]:
return {key: self.generator.generate(value) for key, value in obj.items()}
class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
def __init__(
self,
key_type: Type[enum.Enum],
value_type: Type[T],
context: Optional[ModuleType],
key_type: type[enum.Enum],
value_type: type[T],
context: ModuleType | None,
) -> None:
super().__init__(value_type, context)
@ -203,22 +197,22 @@ class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
if value_type is not str:
raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values")
def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]:
def generate(self, obj: dict[enum.Enum, T]) -> dict[str, JsonType]:
return {key.value: self.generator.generate(value) for key, value in obj.items()}
class TypedSetSerializer(TypedCollectionSerializer[T]):
def generate(self, obj: Set[T]) -> JsonType:
def generate(self, obj: set[T]) -> JsonType:
return [self.generator.generate(item) for item in obj]
class TypedTupleSerializer(Serializer[tuple]):
item_generators: Tuple[Serializer, ...]
item_generators: tuple[Serializer, ...]
def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None:
def __init__(self, item_types: tuple[type, ...], context: ModuleType | None) -> None:
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
def generate(self, obj: tuple) -> List[JsonType]:
def generate(self, obj: tuple) -> list[JsonType]:
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
@ -250,16 +244,16 @@ class FieldSerializer(Generic[T]):
self.property_name = property_name
self.generator = generator
def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None:
def generate_field(self, obj: object, object_dict: dict[str, JsonType]) -> None:
value = getattr(obj, self.field_name)
if value is not None:
object_dict[self.property_name] = self.generator.generate(value)
class TypedClassSerializer(Serializer[T]):
property_generators: List[FieldSerializer]
property_generators: list[FieldSerializer]
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
self.property_generators = [
FieldSerializer(
field_name,
@ -269,8 +263,8 @@ class TypedClassSerializer(Serializer[T]):
for field_name, field_type in get_class_properties(class_type)
]
def generate(self, obj: T) -> Dict[str, JsonType]:
object_dict: Dict[str, JsonType] = {}
def generate(self, obj: T) -> dict[str, JsonType]:
object_dict: dict[str, JsonType] = {}
for property_generator in self.property_generators:
property_generator.generate_field(obj, object_dict)
@ -278,12 +272,12 @@ class TypedClassSerializer(Serializer[T]):
class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]):
def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[NamedTuple], context: ModuleType | None) -> None:
super().__init__(class_type, context)
class DataclassSerializer(TypedClassSerializer[T]):
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
super().__init__(class_type, context)
@ -295,7 +289,7 @@ class UnionSerializer(Serializer):
class LiteralSerializer(Serializer):
generator: Serializer
def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None:
def __init__(self, values: tuple[Any, ...], context: ModuleType | None) -> None:
literal_type_tuple = tuple(type(value) for value in values)
literal_type_set = set(literal_type_tuple)
if len(literal_type_set) != 1:
@ -312,12 +306,12 @@ class LiteralSerializer(Serializer):
class UntypedNamedTupleSerializer(Serializer):
fields: Dict[str, str]
fields: dict[str, str]
def __init__(self, class_type: Type[NamedTuple]) -> None:
def __init__(self, class_type: type[NamedTuple]) -> None:
# named tuples are also instances of tuple
self.fields = {}
field_names: Tuple[str, ...] = class_type._fields
field_names: tuple[str, ...] = class_type._fields
for field_name in field_names:
self.fields[field_name] = python_field_to_json_property(field_name)
@ -351,7 +345,7 @@ class UntypedClassSerializer(Serializer):
return object_dict
def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer:
def create_serializer(typ: TypeLike, context: ModuleType | None = None) -> Serializer:
"""
Creates a serializer engine to produce an object that can be directly converted into a JSON string.
@ -376,8 +370,8 @@ def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Se
return _get_serializer(typ, context)
def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
if isinstance(typ, (str, typing.ForwardRef)):
def _get_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
if isinstance(typ, str | typing.ForwardRef):
if context is None:
raise TypeError(f"missing context for evaluating type: {typ}")
@ -390,13 +384,13 @@ def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
return _create_serializer(typ, context)
@functools.lru_cache(maxsize=None)
@functools.cache
def _fetch_serializer(typ: type) -> Serializer:
context = sys.modules[typ.__module__]
return _create_serializer(typ, context)
def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
def _create_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
# check for well-known types
if typ is type(None):
return NoneSerializer()

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Tuple, Type, TypeVar
from typing import Any, TypeVar
T = TypeVar("T")
class SlotsMeta(type):
def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T:
def __new__(cls: type[T], name: str, bases: tuple[type, ...], ns: dict[str, Any]) -> T:
# caller may have already provided slots, in which case just retain them and keep going
slots: Tuple[str, ...] = ns.get("__slots__", ())
slots: tuple[str, ...] = ns.get("__slots__", ())
# add fields with type annotations to slots
annotations: Dict[str, Any] = ns.get("__annotations__", {})
annotations: dict[str, Any] = ns.get("__annotations__", {})
members = tuple(member for member in annotations.keys() if member not in slots)
# assign slots

View file

@ -10,14 +10,15 @@ Type-safe data interchange for Python data classes.
:see: https://github.com/hunyadi/strong_typing
"""
from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar
from collections.abc import Callable, Iterable
from typing import TypeVar
from .inspection import TypeCollector
T = TypeVar("T")
def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
def topological_sort(graph: dict[T, set[T]]) -> list[T]:
"""
Performs a topological sort of a graph.
@ -29,9 +30,9 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
"""
# empty list that will contain the sorted nodes (in reverse order)
ordered: List[T] = []
ordered: list[T] = []
seen: Dict[T, bool] = {}
seen: dict[T, bool] = {}
def _visit(n: T) -> None:
status = seen.get(n)
@ -57,8 +58,8 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
def type_topological_sort(
types: Iterable[type],
dependency_fn: Optional[Callable[[type], Iterable[type]]] = None,
) -> List[type]:
dependency_fn: Callable[[type], Iterable[type]] | None = None,
) -> list[type]:
"""
Performs a topological sort of a list of types.
@ -78,7 +79,7 @@ def type_topological_sort(
graph = collector.graph
if dependency_fn:
new_types: Set[type] = set()
new_types: set[type] = set()
for source_type, references in graph.items():
dependent_types = dependency_fn(source_type)
references.update(dependent_types)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .ollama import get_distribution_template # noqa: F401

View file

@ -0,0 +1,39 @@
version: 2
distribution_spec:
description: Use (an external) Ollama server for running LLM inference
providers:
inference:
- remote::ollama
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
files:
- inline::localfs
post_training:
- inline::huggingface
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
- remote::wolfram-alpha
image_type: conda
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,168 @@
# Ollama Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Ollama Server
This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps:
1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/)
2. **Start the Ollama server**:
```bash
ollama serve
```
By default, Ollama serves on `http://127.0.0.1:11434`
3. **Pull the required models**:
```bash
# Pull the inference model
ollama pull meta-llama/Llama-3.2-3B-Instruct
# Pull the embedding model
ollama pull all-minilm:latest
# (Optional) Pull the safety model for run-with-safety.yaml
ollama pull meta-llama/Llama-Guard-3-1B
```
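
Once the server is running and the models are pulled, you can sanity-check the setup by asking Ollama which models it has available. A minimal sketch using only the Python standard library, assuming Ollama is listening on the default `http://127.0.0.1:11434`:

```python
import json
import urllib.request

OLLAMA_URL = "http://127.0.0.1:11434"  # default address; adjust if your server differs

# Ollama exposes GET /api/tags, which lists the locally available models
with urllib.request.urlopen(f"{OLLAMA_URL}/api/tags") as resp:
    models = json.load(resp).get("models", [])

print("Models available to Ollama:")
for model in models:
    print(f"  - {model['name']}")
```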
## Supported Services
### Inference: Ollama
Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable.
### Vector IO: FAISS
Provides vector storage capabilities using FAISS for embeddings and similarity search operations.
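
For a sense of how this is exercised from the client side, here is a rough sketch using the Llama Stack client; the `vector_dbs`/`vector_io` method names and the dict-based chunk format follow the client SDK but may vary between versions, so treat the details as assumptions rather than a reference:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a vector database backed by the FAISS provider, using the
# embedding model shipped with this distribution (see run.yaml).
client.vector_dbs.register(
    vector_db_id="my_documents",
    provider_id="faiss",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

# Insert a couple of chunks and run a similarity query against them.
client.vector_io.insert(
    vector_db_id="my_documents",
    chunks=[
        {"content": "Llama Stack uses XDG-compliant directories.", "metadata": {"doc": "readme"}},
        {"content": "Ollama serves models over a local HTTP API.", "metadata": {"doc": "ollama"}},
    ],
)

results = client.vector_io.query(vector_db_id="my_documents", query="Where does Llama Stack store data?")
for chunk in results.chunks:
    print(chunk.content)
```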
### Safety: Llama Guard (Optional)
When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server.
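
Shields can then be invoked directly through the client before a message ever reaches the model. A sketch, assuming the default `SAFETY_MODEL` (`meta-llama/Llama-Guard-3-1B`) has been registered as a shield via `run-with-safety.yaml`:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Run the Llama Guard shield against a candidate user message
result = client.safety.run_shield(
    shield_id="meta-llama/Llama-Guard-3-1B",  # matches the SAFETY_MODEL default
    messages=[{"role": "user", "content": "Tell me how to pick a lock."}],
    params={},
)

if result.violation:
    print(f"Blocked by shield: {result.violation.user_message}")
else:
    print("Message passed the safety check.")
```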
### Agents: Meta Reference
Provides agent execution capabilities using the meta-reference implementation.
### Post-Training: Hugging Face
Supports model fine-tuning using Hugging Face integration.
### Tool Runtime
Supports various external tools including:
- Brave Search
- Tavily Search
- RAG Runtime
- Model Context Protocol
- Wolfram Alpha
## Running Llama Stack with Ollama
You can do this via Conda or venv (which build the distribution code locally), or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via Conda
```bash
llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
llama stack build --template ollama --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Running with Safety
To enable safety checks, use the `run-with-safety.yaml` configuration:
```bash
llama stack run ./run-with-safety.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL
```
## Example Usage
Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
# Run inference
response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.2-3B-Instruct",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.completion_message.content)
```
## Troubleshooting
### Common Issues
1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL.
2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull <model-name>`.
3. **Performance issues**: Consider using a smaller model, running Ollama on more capable hardware, or adjusting the Ollama server configuration for better performance.
### Logs
Check the Ollama server logs for any issues:
```bash
# Ollama logs are typically available in:
# - macOS: ~/Library/Logs/Ollama/
# - Linux: ~/.ollama/logs/
```

View file

@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::ollama"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"files": ["inline::localfs"],
"post_training": ["inline::huggingface"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
"remote::wolfram-alpha",
],
}
name = "ollama"
inference_provider = Provider(
provider_id="ollama",
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(),
)
vector_io_provider_faiss = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(
f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/distributions/{name}"
),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
),
)
posttraining_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="ollama",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="ollama",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="ollama",
provider_model_id="all-minilm:latest",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use (an external) Ollama server for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
config={},
),
Provider(
provider_id="code-scanner",
provider_type="inline::code-scanner",
config={},
),
],
},
default_models=[
inference_model,
safety_model,
embedding_model,
],
default_shields=[
ShieldInput(
shield_id="${env.SAFETY_MODEL}",
provider_id="llama-guard",
),
ShieldInput(
shield_id="CodeScanner",
provider_id="code-scanner",
),
],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"OLLAMA_URL": (
"http://127.0.0.1:11434",
"URL of the Ollama server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the Ollama server",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",
"Safety model loaded into the Ollama server",
),
},
)

View file

@ -0,0 +1,158 @@
version: 2
image_name: ollama
apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
- provider_id: code-scanner
provider_type: inline::code-scanner
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: llama-guard
- shield_id: CodeScanner
provider_id: code-scanner
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321

View file

@ -0,0 +1,148 @@
version: 2
image_name: ollama
apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321