Mustafa Elbehery 2025-07-24 23:57:19 +02:00 committed by GitHub
commit 172d578b20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 5611 additions and 508 deletions


@ -0,0 +1,172 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Ollama Distribution
The `llamastack/distribution-ollama` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| files | `inline::localfs` |
| inference | `remote::ollama` |
| post_training | `inline::huggingface` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
### Environment Variables
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
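As an illustration, resolving these variables against their documented defaults can be sketched in a few lines of Python (the `DEFAULTS` values come from the list above; the `resolve` helper itself is hypothetical, not part of the distribution):

```python
import os

# Defaults mirror the environment variable list above; override via the environment.
DEFAULTS = {
    "LLAMA_STACK_PORT": "8321",
    "OLLAMA_URL": "http://127.0.0.1:11434",
    "INFERENCE_MODEL": "meta-llama/Llama-3.2-3B-Instruct",
    "SAFETY_MODEL": "meta-llama/Llama-Guard-3-1B",
}


def resolve(name: str) -> str:
    """Return the configured value, falling back to the documented default."""
    return os.environ.get(name, DEFAULTS[name])
```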
## Prerequisites
### Ollama Server
This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps:
1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/)
2. **Start the Ollama server**:
```bash
ollama serve
```
By default, Ollama serves on `http://127.0.0.1:11434`
3. **Pull the required models**:
```bash
# Pull the inference model
ollama pull meta-llama/Llama-3.2-3B-Instruct
# Pull the embedding model
ollama pull all-minilm:latest
# (Optional) Pull the safety model for run-with-safety.yaml
ollama pull meta-llama/Llama-Guard-3-1B
```
## Supported Services
### Inference: Ollama
Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable.
### Vector IO: FAISS
Provides vector storage capabilities using FAISS for embeddings and similarity search operations.
### Safety: Llama Guard (Optional)
When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server.
### Agents: Meta Reference
Provides agent execution capabilities using the meta-reference implementation.
### Post-Training: Hugging Face
Supports model fine-tuning using Hugging Face integration.
### Tool Runtime
Supports various external tools including:
- Brave Search
- Tavily Search
- RAG Runtime
- Model Context Protocol
- Wolfram Alpha
## Running Llama Stack with Ollama
You can run the distribution via Conda or venv (building the distribution code locally), or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-ollama \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via Conda
```bash
llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
llama stack build --template ollama --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Running with Safety
To enable safety checks, use the `run-with-safety.yaml` configuration:
```bash
llama stack run ./run-with-safety.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL
```
## Example Usage
Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
# Run inference
response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.2-3B-Instruct",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.completion_message.content)
```
## Troubleshooting
### Common Issues
1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL.
2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull <model-name>`.
3. **Performance issues**: Consider using more powerful models or adjusting the Ollama server configuration for better performance.
### Logs
Check the Ollama server logs for any issues:
```bash
# Ollama logs are typically available in:
# - macOS: ~/Library/Logs/Ollama/
# - Linux: ~/.ollama/logs/
```


@ -0,0 +1,191 @@
# XDG Base Directory Specification Compliance
Starting with version 0.2.14, Llama Stack follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) for organizing configuration and data files. This provides better integration with modern desktop environments and allows for more flexible customization of where files are stored.
## Overview
The XDG Base Directory Specification defines standard locations for different types of application data:
- **Configuration files** (`XDG_CONFIG_HOME`): User-specific configuration files
- **Data files** (`XDG_DATA_HOME`): User-specific data files that should persist
- **Cache files** (`XDG_CACHE_HOME`): User-specific cache files
- **State files** (`XDG_STATE_HOME`): User-specific state files
## Directory Mapping
Llama Stack now uses the following XDG-compliant directory structure:
| Data Type | XDG Directory | Default Location | Description |
|-----------|---------------|------------------|-------------|
| Configuration | `XDG_CONFIG_HOME` | `~/.config/llama-stack` | Distribution configs, provider configs |
| Data | `XDG_DATA_HOME` | `~/.local/share/llama-stack` | Model checkpoints, persistent files |
| Cache | `XDG_CACHE_HOME` | `~/.cache/llama-stack` | Temporary cache files |
| State | `XDG_STATE_HOME` | `~/.local/state/llama-stack` | Runtime state, databases |
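The mapping above can be reproduced in a few lines. The helper below is a sketch of how the defaults are derived under the XDG spec (it is illustrative only, not the library's actual API):

```python
import os
from pathlib import Path

# XDG variable name and its spec-defined fallback, per data type.
_XDG_DEFAULTS = {
    "config": ("XDG_CONFIG_HOME", "~/.config"),
    "data": ("XDG_DATA_HOME", "~/.local/share"),
    "cache": ("XDG_CACHE_HOME", "~/.cache"),
    "state": ("XDG_STATE_HOME", "~/.local/state"),
}


def llama_stack_dir(kind: str) -> Path:
    """Resolve the llama-stack directory for a given data type."""
    var, default = _XDG_DEFAULTS[kind]
    base = os.environ.get(var) or os.path.expanduser(default)
    return Path(base) / "llama-stack"
```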
## Environment Variables
You can customize the locations by setting these environment variables:
```bash
# Override the base directories
export XDG_CONFIG_HOME="/custom/config/path"
export XDG_DATA_HOME="/custom/data/path"
export XDG_CACHE_HOME="/custom/cache/path"
export XDG_STATE_HOME="/custom/state/path"
# Or override specific Llama Stack directories
export SQLITE_STORE_DIR="/custom/database/path"
export FILES_STORAGE_DIR="/custom/files/path"
```
## Backwards Compatibility
Llama Stack maintains full backwards compatibility with existing installations:
1. **Legacy Environment Variable**: If `LLAMA_STACK_CONFIG_DIR` is set, it will be used for all directories
2. **Legacy Directory Detection**: If `~/.llama` exists and contains data, it will continue to be used
3. **Gradual Migration**: New installations use XDG paths, existing installations continue to work
## Migration Guide
### Automatic Migration
Use the built-in migration command to move from legacy `~/.llama` to XDG-compliant directories:
```bash
# Preview what would be migrated
llama migrate-xdg --dry-run
# Perform the migration
llama migrate-xdg
```
### Manual Migration
If you prefer to migrate manually, here's the mapping:
```bash
# Create XDG directories
mkdir -p ~/.config/llama-stack
mkdir -p ~/.local/share/llama-stack
mkdir -p ~/.local/state/llama-stack
# Move configuration files
mv ~/.llama/distributions ~/.config/llama-stack/
mv ~/.llama/providers.d ~/.config/llama-stack/
# Move data files
mv ~/.llama/checkpoints ~/.local/share/llama-stack/
# Move state files
mv ~/.llama/runtime ~/.local/state/llama-stack/
# Clean up empty legacy directory
rmdir ~/.llama
```
### Environment Variables in Configurations
Template configurations now use XDG-compliant environment variables:
```yaml
# Old format
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
# New format
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/registry.db
```
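The `${env.VAR:=default}` substitution shown above nests: the outer default itself contains an `${env.XDG_STATE_HOME:-...}` expression. A small resolver can model this by repeatedly expanding the innermost expression first; this is a sketch of the substitution semantics, not the stack's actual implementation:

```python
import re

# Matches the innermost ${env.NAME:=default} or ${env.NAME:-default} expression
# (the default must not itself contain a nested expression, hence [^${}]*).
_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):[-=]([^${}]*)\}")


def expand(template: str, env: dict[str, str]) -> str:
    """Repeatedly resolve innermost substitutions so nested defaults work."""
    while True:
        new = _PATTERN.sub(lambda m: env.get(m.group(1), m.group(2)), template)
        if new == template:
            return new
        template = new
```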
## Configuration Examples
### Using Custom XDG Directories
```bash
# Set custom XDG directories
export XDG_CONFIG_HOME="/opt/llama-stack/config"
export XDG_DATA_HOME="/opt/llama-stack/data"
export XDG_STATE_HOME="/opt/llama-stack/state"
# Start Llama Stack
llama stack run my-distribution.yaml
```
### Using Legacy Directory
```bash
# Continue using legacy directory
export LLAMA_STACK_CONFIG_DIR="/home/user/.llama"
# Start Llama Stack
llama stack run my-distribution.yaml
```
### Custom Database and File Locations
```bash
# Override specific directories
export SQLITE_STORE_DIR="/fast/ssd/llama-stack/databases"
export FILES_STORAGE_DIR="/large/disk/llama-stack/files"
# Start Llama Stack
llama stack run my-distribution.yaml
```
## Benefits of XDG Compliance
1. **Standards Compliance**: Follows established Linux/Unix conventions
2. **Better Organization**: Separates configuration, data, cache, and state files
3. **Flexibility**: Easy to customize storage locations
4. **Backup-Friendly**: Easier to backup just data files or just configuration
5. **Multi-User Support**: Better support for shared systems
6. **Cross-Platform**: Works consistently across different environments
## Template Updates
All distribution templates have been updated to use XDG-compliant paths:
- Database files use `XDG_STATE_HOME`
- Model checkpoints use `XDG_DATA_HOME`
- Configuration files use `XDG_CONFIG_HOME`
- Cache files use `XDG_CACHE_HOME`
## Troubleshooting
### Migration Issues
If you encounter issues during migration:
1. **Check Permissions**: Ensure you have write permissions to target directories
2. **Disk Space**: Verify sufficient disk space in target locations
3. **Existing Files**: Handle conflicts with existing files in target locations
### Environment Variable Conflicts
If you have multiple environment variables set:
1. `LLAMA_STACK_CONFIG_DIR` takes highest precedence
2. Individual `XDG_*` variables override defaults
3. Fallback to legacy `~/.llama` if it exists
4. Default to XDG standard paths for new installations
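The precedence rules above can be sketched as a pure function (mirroring the behavior of `get_llama_stack_config_dir` in this commit; the function name and the `legacy_exists` flag are illustrative assumptions):

```python
import os
from pathlib import Path


def resolve_config_dir(env: dict[str, str], legacy_exists: bool) -> Path:
    """Apply the documented precedence for the config directory."""
    # 1. LLAMA_STACK_CONFIG_DIR takes highest precedence
    if "LLAMA_STACK_CONFIG_DIR" in env:
        return Path(env["LLAMA_STACK_CONFIG_DIR"])
    # 3. An existing, populated ~/.llama keeps working
    if legacy_exists:
        return Path(os.path.expanduser("~/.llama"))
    # 2./4. XDG variable if set, otherwise the XDG default path
    base = env.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
    return Path(base) / "llama-stack"
```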
### Debugging Path Resolution
To see which paths Llama Stack is using:
```python
from llama_stack.distribution.utils.xdg_utils import (
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
)
print(f"Config: {get_llama_stack_config_dir()}")
print(f"Data: {get_llama_stack_data_dir()}")
print(f"State: {get_llama_stack_state_dir()}")
```
## Future Considerations
- Container deployments will continue to use `/app` or similar paths
- Cloud deployments may use provider-specific storage systems
- The XDG specification primarily applies to local development and single-user systems


@ -16,10 +16,10 @@ Meta's reference implementation of an agent system that can use tools, access ve
```yaml
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/responses_store.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/responses_store.db
```


@ -15,7 +15,7 @@ Local filesystem-based dataset I/O provider for reading and writing datasets to
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/localfs_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/localfs_datasetio.db
```


@ -15,7 +15,7 @@ HuggingFace datasets provider for accessing and managing datasets from the Huggi
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/huggingface_datasetio.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/huggingface_datasetio.db
```


@ -15,7 +15,7 @@ Meta's reference implementation of evaluation tasks with support for multiple la
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/meta_reference_eval.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/meta_reference_eval.db
```


@ -15,10 +15,10 @@ Local filesystem-based file storage provider for managing files and documents lo
## Sample Configuration
```yaml
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/dummy/files}
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/files_metadata.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/files_metadata.db
```


@ -11,14 +11,14 @@ Meta's reference implementation of telemetry and observability using OpenTelemet
| `otel_exporter_otlp_endpoint` | `str \| None` | No | | The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable. |
| `service_name` | `<class 'str'>` | No | | The service name to use for telemetry |
| `sinks` | `list[inline.telemetry.meta_reference.config.TelemetrySink]` | No | [<TelemetrySink.CONSOLE: 'console'>, <TelemetrySink.SQLITE: 'sqlite'>] | List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console) |
| `sqlite_db_path` | `<class 'str'>` | No | ~/.llama/runtime/trace_store.db | The path to the SQLite database to use for storing traces |
| `sqlite_db_path` | `<class 'str'>` | No | ${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/runtime/trace_store.db | The path to the SQLite database to use for storing traces |
## Sample Configuration
```yaml
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/trace_store.db
sqlite_db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
```


@ -44,7 +44,7 @@ more details about Faiss in general.
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db
```


@ -15,7 +15,7 @@ Meta's reference implementation of a vector database.
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db
```


@ -17,10 +17,10 @@ Please refer to the remote provider documentation.
## Sample Configuration
```yaml
db_path: ${env.MILVUS_DB_PATH:=~/.llama/dummy}/milvus.db
db_path: ${env.MILVUS_DB_PATH:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_registry.db
```


@ -55,7 +55,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
## Sample Configuration
```yaml
path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db
path: ${env.QDRANT_PATH:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/qdrant.db
```


@ -211,10 +211,10 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
## Sample Configuration
```yaml
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db
```


@ -16,10 +16,10 @@ Please refer to the sqlite-vec provider documentation.
## Sample Configuration
```yaml
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db
```


@ -126,7 +126,7 @@ uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_remote_registry.db
```


@ -52,7 +52,7 @@ user: ${env.PGVECTOR_USER}
password: ${env.PGVECTOR_PASSWORD}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/pgvector_registry.db
```


@ -38,7 +38,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/weaviate_registry.db
```


@ -7,6 +7,7 @@
import argparse
from .download import Download
from .migrate_xdg import MigrateXDG
from .model import ModelParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
@ -34,6 +35,7 @@ class LlamaCLIParser:
        StackParser.create(subparsers)
        Download.create(subparsers)
        VerifyDownload.create(subparsers)
        MigrateXDG.create(subparsers)

        print_subcommand_description(self.parser, subparsers)


@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path

from llama_stack.distribution.utils.xdg_utils import (
    get_llama_stack_config_dir,
    get_llama_stack_data_dir,
    get_llama_stack_state_dir,
)

from .subcommand import Subcommand


class MigrateXDG(Subcommand):
    """CLI command for migrating from legacy ~/.llama to XDG-compliant directories."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "migrate-xdg",
            prog="llama migrate-xdg",
            description="Migrate from legacy ~/.llama to XDG-compliant directories",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self.parser.add_argument(
            "--dry-run", action="store_true", help="Show what would be done without actually moving files"
        )
        self.parser.set_defaults(func=self._run_migrate_xdg_cmd)

    @staticmethod
    def create(subparsers: argparse._SubParsersAction):
        return MigrateXDG(subparsers)

    def _run_migrate_xdg_cmd(self, args: argparse.Namespace) -> None:
        """Run the migrate-xdg command."""
        if not migrate_to_xdg(dry_run=args.dry_run):
            sys.exit(1)


def migrate_to_xdg(dry_run: bool = False) -> bool:
    """
    Migrate from legacy ~/.llama to XDG-compliant directories.

    Args:
        dry_run: If True, only show what would be done without actually moving files

    Returns:
        bool: True if migration was successful or not needed, False otherwise
    """
    legacy_path = Path.home() / ".llama"
    if not legacy_path.exists():
        print("No legacy ~/.llama directory found. Nothing to migrate.")
        return True

    # Check if we're already using XDG paths
    config_dir = get_llama_stack_config_dir()
    data_dir = get_llama_stack_data_dir()
    state_dir = get_llama_stack_state_dir()

    if str(config_dir) == str(legacy_path):
        print("Already using legacy directory. No migration needed.")
        return True

    print(f"Found legacy directory at: {legacy_path}")
    print("Will migrate to XDG-compliant directories:")
    print(f"  Config: {config_dir}")
    print(f"  Data: {data_dir}")
    print(f"  State: {state_dir}")
    print()

    # Define migration mapping
    migrations = [
        # (source_subdir, target_base_dir, description)
        ("distributions", config_dir, "Distribution configurations"),
        ("providers.d", config_dir, "External provider configurations"),
        ("checkpoints", data_dir, "Model checkpoints"),
        ("runtime", state_dir, "Runtime state files"),
    ]

    # Check what needs to be migrated
    items_to_migrate = []
    for subdir, target_base, description in migrations:
        source_path = legacy_path / subdir
        if source_path.exists():
            target_path = target_base / subdir
            items_to_migrate.append((source_path, target_path, description))

    if not items_to_migrate:
        print("No items found to migrate.")
        return True

    print("Items to migrate:")
    for source_path, target_path, description in items_to_migrate:
        print(f"  {description}: {source_path} -> {target_path}")

    if dry_run:
        print("\nDry run mode: No files will be moved.")
        return True

    # Ask for confirmation
    response = input("\nDo you want to proceed with the migration? (y/N): ")
    if response.lower() not in ["y", "yes"]:
        print("Migration cancelled.")
        return False

    # Perform the migration
    print("\nMigrating files...")
    for source_path, target_path, description in items_to_migrate:
        try:
            # Create target directory if it doesn't exist
            target_path.parent.mkdir(parents=True, exist_ok=True)

            # Check if target already exists
            if target_path.exists():
                print(f"  Warning: Target already exists: {target_path}")
                print(f"  Skipping {description}")
                continue

            # Move the directory
            shutil.move(str(source_path), str(target_path))
            print(f"  Moved {description}: {source_path} -> {target_path}")
        except Exception as e:
            print(f"  Error migrating {description}: {e}")
            return False

    # Check if legacy directory is now empty (except for hidden files)
    remaining_items = [item for item in legacy_path.iterdir() if not item.name.startswith(".")]
    if not remaining_items:
        print(f"\nMigration complete! Legacy directory {legacy_path} is now empty.")
        response = input("Remove empty legacy directory? (y/N): ")
        if response.lower() in ["y", "yes"]:
            try:
                shutil.rmtree(legacy_path)
                print(f"Removed empty legacy directory: {legacy_path}")
            except Exception as e:
                print(f"Could not remove legacy directory: {e}")
    else:
        print(f"\nMigration complete! Some items remain in legacy directory: {remaining_items}")

    print("\nMigration successful!")
    print("You may need to update any custom scripts or configurations that reference the old paths.")
    return True


def main():
    parser = argparse.ArgumentParser(description="Migrate from legacy ~/.llama to XDG-compliant directories")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be done without actually moving files")
    args = parser.parse_args()
    if not migrate_to_xdg(dry_run=args.dry_run):
        sys.exit(1)


if __name__ == "__main__":
    main()


@ -7,12 +7,35 @@
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
from .xdg_utils import (
    get_llama_stack_config_dir,
    get_llama_stack_data_dir,
    get_llama_stack_state_dir,
)
# Base directory for all llama-stack configuration
# This now uses XDG-compliant paths with backwards compatibility
LLAMA_STACK_CONFIG_DIR = get_llama_stack_config_dir()
# Distribution configurations - stored in config directory
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
# Model checkpoints - stored in data directory (persistent data)
DEFAULT_CHECKPOINT_DIR = get_llama_stack_data_dir() / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
# Runtime data - stored in state directory
RUNTIME_BASE_DIR = get_llama_stack_state_dir() / "runtime"
# External providers - stored in config directory
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
# Legacy compatibility: if the legacy environment variable is set, use it for all paths
# This ensures that existing installations continue to work
legacy_config_dir = os.getenv("LLAMA_STACK_CONFIG_DIR")
if legacy_config_dir:
    legacy_base = Path(legacy_config_dir)
    LLAMA_STACK_CONFIG_DIR = legacy_base
    DISTRIBS_BASE_DIR = legacy_base / "distributions"
    DEFAULT_CHECKPOINT_DIR = legacy_base / "checkpoints"
    RUNTIME_BASE_DIR = legacy_base / "runtime"
    EXTERNAL_PROVIDERS_DIR = legacy_base / "providers.d"


@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pathlib import Path


def get_xdg_config_home() -> Path:
    """
    Get the XDG config home directory.

    Returns:
        Path: XDG_CONFIG_HOME if set, otherwise ~/.config
    """
    return Path(os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")))


def get_xdg_data_home() -> Path:
    """
    Get the XDG data home directory.

    Returns:
        Path: XDG_DATA_HOME if set, otherwise ~/.local/share
    """
    return Path(os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share")))


def get_xdg_cache_home() -> Path:
    """
    Get the XDG cache home directory.

    Returns:
        Path: XDG_CACHE_HOME if set, otherwise ~/.cache
    """
    return Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")))


def get_xdg_state_home() -> Path:
    """
    Get the XDG state home directory.

    Returns:
        Path: XDG_STATE_HOME if set, otherwise ~/.local/state
    """
    return Path(os.environ.get("XDG_STATE_HOME", os.path.expanduser("~/.local/state")))


def get_llama_stack_config_dir() -> Path:
    """
    Get the llama-stack configuration directory.

    This function provides backwards compatibility by checking for the legacy
    LLAMA_STACK_CONFIG_DIR environment variable first, then falling back to
    XDG-compliant paths.

    Returns:
        Path: Configuration directory for llama-stack
    """
    # Check for legacy environment variable first for backwards compatibility
    legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
    if legacy_dir:
        return Path(legacy_dir)

    # Check if legacy ~/.llama directory exists and contains data
    legacy_path = Path.home() / ".llama"
    if legacy_path.exists() and any(legacy_path.iterdir()):
        return legacy_path

    # Use XDG-compliant path
    return get_xdg_config_home() / "llama-stack"


def get_llama_stack_data_dir() -> Path:
    """
    Get the llama-stack data directory.

    This is used for persistent data like model checkpoints.

    Returns:
        Path: Data directory for llama-stack
    """
    # Check for legacy environment variable first for backwards compatibility
    legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
    if legacy_dir:
        return Path(legacy_dir)

    # Check if legacy ~/.llama directory exists and contains data
    legacy_path = Path.home() / ".llama"
    if legacy_path.exists() and any(legacy_path.iterdir()):
        return legacy_path

    # Use XDG-compliant path
    return get_xdg_data_home() / "llama-stack"


def get_llama_stack_cache_dir() -> Path:
    """
    Get the llama-stack cache directory.

    This is used for temporary/cache data.

    Returns:
        Path: Cache directory for llama-stack
    """
    # Check for legacy environment variable first for backwards compatibility
    legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
    if legacy_dir:
        return Path(legacy_dir)

    # Check if legacy ~/.llama directory exists and contains data
    legacy_path = Path.home() / ".llama"
    if legacy_path.exists() and any(legacy_path.iterdir()):
        return legacy_path

    # Use XDG-compliant path
    return get_xdg_cache_home() / "llama-stack"


def get_llama_stack_state_dir() -> Path:
    """
    Get the llama-stack state directory.

    This is used for runtime state data.

    Returns:
        Path: State directory for llama-stack
    """
    # Check for legacy environment variable first for backwards compatibility
    legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
    if legacy_dir:
        return Path(legacy_dir)

    # Check if legacy ~/.llama directory exists and contains data
    legacy_path = Path.home() / ".llama"
    if legacy_path.exists() and any(legacy_path.iterdir()):
        return legacy_path

    # Use XDG-compliant path
    return get_xdg_state_home() / "llama-stack"


def get_xdg_compliant_path(path_type: str, subdirectory: str | None = None, legacy_fallback: bool = True) -> Path:
    """
    Get an XDG-compliant path for a given type.

    Args:
        path_type: Type of path ('config', 'data', 'cache', 'state')
        subdirectory: Optional subdirectory within the base path
        legacy_fallback: Whether to check for legacy ~/.llama directory

    Returns:
        Path: XDG-compliant path

    Raises:
        ValueError: If path_type is not recognized
    """
    path_map = {
        "config": get_llama_stack_config_dir,
        "data": get_llama_stack_data_dir,
        "cache": get_llama_stack_cache_dir,
        "state": get_llama_stack_state_dir,
    }
    if path_type not in path_map:
        raise ValueError(f"Unknown path type: {path_type}. Must be one of: {list(path_map.keys())}")

    base_path = path_map[path_type]()
    if subdirectory:
        return base_path / subdirectory
    return base_path


def migrate_legacy_directory() -> bool:
    """
    Migrate from legacy ~/.llama directory to XDG-compliant directories.

    This function helps users migrate their existing data to the new
    XDG-compliant structure.

    Returns:
        bool: True if migration was successful or not needed, False otherwise
    """
    legacy_path = Path.home() / ".llama"
    if not legacy_path.exists():
        return True  # No migration needed

    print(f"Found legacy directory at {legacy_path}")
    print("Consider migrating to XDG-compliant directories:")
    print(f"  Config: {get_llama_stack_config_dir()}")
    print(f"  Data: {get_llama_stack_data_dir()}")
    print(f"  Cache: {get_llama_stack_cache_dir()}")
    print(f"  State: {get_llama_stack_state_dir()}")
    print("Migration can be done by moving the appropriate subdirectories.")
    return True


def ensure_directory_exists(path: Path) -> None:
    """
    Ensure a directory exists, creating it if necessary.

    Args:
        path: Path to the directory
    """
    path.mkdir(parents=True, exist_ok=True)


@ -4,6 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64


@ -12,23 +12,12 @@ Type-safe data interchange for Python data classes.
import dataclasses
import sys
from collections.abc import Callable
from dataclasses import is_dataclass
from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload
if sys.version_info >= (3, 9):
    from typing import Annotated as Annotated
else:
    from typing_extensions import Annotated as Annotated

if sys.version_info >= (3, 10):
    from typing import TypeAlias as TypeAlias
else:
    from typing_extensions import TypeAlias as TypeAlias

if sys.version_info >= (3, 11):
    from typing import dataclass_transform as dataclass_transform
else:
    from typing_extensions import dataclass_transform as dataclass_transform
from typing import Annotated as Annotated
from typing import TypeAlias as TypeAlias
from typing import TypeVar, overload
from typing import dataclass_transform as dataclass_transform
T = TypeVar("T")
@ -56,17 +45,17 @@ class CompactDataClass:
@overload
def typeannotation(cls: Type[T], /) -> Type[T]: ...
def typeannotation(cls: type[T], /) -> type[T]: ...
@overload
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ...
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[type[T]], type[T]]: ...
@dataclass_transform(eq_default=True, order_default=False)
def typeannotation(
cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
cls: type[T] | None = None, *, eq: bool = True, order: bool = False
) -> type[T] | Callable[[type[T]], type[T]]:
"""
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
@ -76,7 +65,7 @@ def typeannotation(
:returns: A data-class type, or a wrapper for data-class types.
"""
def wrap(cls: Type[T]) -> Type[T]:
def wrap(cls: type[T]) -> type[T]:
# mypy fails to equate bound-y functions (first argument interpreted as
# the bound object) with class methods, hence the `ignore` directive.
cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign]
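The `wrap()` pattern in this hunk assigns dunder methods directly on the class object; a standalone sketch of the same technique (the decorator and repr helper below are illustrative, not from the patch):

```python
from dataclasses import dataclass, fields

def _compact_repr(self) -> str:
    # one-line repr listing field=value pairs, mirroring wrap() above
    args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
    return f"{type(self).__name__}({args})"

def compact(cls):
    # assign the dunder on the class object, as wrap() does with __repr__
    cls.__repr__ = _compact_repr
    return cls

@compact
@dataclass
class Point:
    x: int
    y: int
```

`repr(Point(1, 2))` then yields `Point(x=1, y=2)`, overriding the dataclass-generated repr.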
@ -179,41 +168,41 @@ class SpecialConversion:
"Indicates that the annotated type is subject to custom conversion rules."
int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
int32: TypeAlias = Annotated[
type int8 = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
type int16 = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
type int32 = Annotated[
int,
Signed(True),
Storage(4),
IntegerRange(-2147483648, 2147483647),
]
int64: TypeAlias = Annotated[
type int64 = Annotated[
int,
Signed(True),
Storage(8),
IntegerRange(-9223372036854775808, 9223372036854775807),
]
uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
uint32: TypeAlias = Annotated[
type uint8 = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
type uint16 = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
type uint32 = Annotated[
int,
Signed(False),
Storage(4),
IntegerRange(0, 4294967295),
]
uint64: TypeAlias = Annotated[
type uint64 = Annotated[
int,
Signed(False),
Storage(8),
IntegerRange(0, 18446744073709551615),
]
float32: TypeAlias = Annotated[float, Storage(4)]
float64: TypeAlias = Annotated[float, Storage(8)]
type float32 = Annotated[float, Storage(4)]
type float64 = Annotated[float, Storage(8)]
# maps globals of type Annotated[T, ...] defined in this module to their string names
_auxiliary_types: Dict[object, str] = {}
_auxiliary_types: dict[object, str] = {}
module = sys.modules[__name__]
for var in dir(module):
typ = getattr(module, var)
@ -222,7 +211,7 @@ for var in dir(module):
_auxiliary_types[typ] = var
def get_auxiliary_format(data_type: object) -> Optional[str]:
def get_auxiliary_format(data_type: object) -> str | None:
"Returns the JSON format string corresponding to an auxiliary type."
return _auxiliary_types.get(data_type)
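Whether spelled as `TypeAlias` annotations or PEP 695 `type` statements, the aliases above attach metadata through `Annotated`, and `typing.get_args` recovers it; the `Storage` dataclass below is a minimal stand-in for the one defined in this module:

```python
from dataclasses import dataclass
from typing import Annotated, get_args

@dataclass
class Storage:
    # stand-in for the module's Storage annotation class
    bytes: int

# illustrative alias in the same shape as int8/uint8 above
int8 = Annotated[int, Storage(1)]

# get_args returns the base type followed by the attached metadata objects
base, *meta = get_args(int8)
```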

View file

@ -12,12 +12,11 @@ import enum
import ipaddress
import math
import re
import sys
import types
import typing
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Literal, TypeVar, Union
from .auxiliary import (
Alias,
@ -40,57 +39,57 @@ T = TypeVar("T")
@dataclass
class JsonSchemaNode:
title: Optional[str]
description: Optional[str]
title: str | None
description: str | None
@dataclass
class JsonSchemaType(JsonSchemaNode):
type: str
format: Optional[str]
format: str | None
@dataclass
class JsonSchemaBoolean(JsonSchemaType):
type: Literal["boolean"]
const: Optional[bool]
default: Optional[bool]
examples: Optional[List[bool]]
const: bool | None
default: bool | None
examples: list[bool] | None
@dataclass
class JsonSchemaInteger(JsonSchemaType):
type: Literal["integer"]
const: Optional[int]
default: Optional[int]
examples: Optional[List[int]]
enum: Optional[List[int]]
minimum: Optional[int]
maximum: Optional[int]
const: int | None
default: int | None
examples: list[int] | None
enum: list[int] | None
minimum: int | None
maximum: int | None
@dataclass
class JsonSchemaNumber(JsonSchemaType):
type: Literal["number"]
const: Optional[float]
default: Optional[float]
examples: Optional[List[float]]
minimum: Optional[float]
maximum: Optional[float]
exclusiveMinimum: Optional[float]
exclusiveMaximum: Optional[float]
multipleOf: Optional[float]
const: float | None
default: float | None
examples: list[float] | None
minimum: float | None
maximum: float | None
exclusiveMinimum: float | None
exclusiveMaximum: float | None
multipleOf: float | None
@dataclass
class JsonSchemaString(JsonSchemaType):
type: Literal["string"]
const: Optional[str]
default: Optional[str]
examples: Optional[List[str]]
enum: Optional[List[str]]
minLength: Optional[int]
maxLength: Optional[int]
const: str | None
default: str | None
examples: list[str] | None
enum: list[str] | None
minLength: int | None
maxLength: int | None
@dataclass
@ -102,9 +101,9 @@ class JsonSchemaArray(JsonSchemaType):
@dataclass
class JsonSchemaObject(JsonSchemaType):
type: Literal["object"]
properties: Optional[Dict[str, "JsonSchemaAny"]]
additionalProperties: Optional[bool]
required: Optional[List[str]]
properties: dict[str, "JsonSchemaAny"] | None
additionalProperties: bool | None
required: list[str] | None
@dataclass
@ -114,24 +113,24 @@ class JsonSchemaRef(JsonSchemaNode):
@dataclass
class JsonSchemaAllOf(JsonSchemaNode):
allOf: List["JsonSchemaAny"]
allOf: list["JsonSchemaAny"]
@dataclass
class JsonSchemaAnyOf(JsonSchemaNode):
anyOf: List["JsonSchemaAny"]
anyOf: list["JsonSchemaAny"]
@dataclass
class Discriminator:
propertyName: str
mapping: Dict[str, str]
mapping: dict[str, str]
@dataclass
class JsonSchemaOneOf(JsonSchemaNode):
oneOf: List["JsonSchemaAny"]
discriminator: Optional[Discriminator]
oneOf: list["JsonSchemaAny"]
discriminator: Discriminator | None
JsonSchemaAny = Union[
@ -149,10 +148,10 @@ JsonSchemaAny = Union[
@dataclass
class JsonSchemaTopLevelObject(JsonSchemaObject):
schema: Annotated[str, Alias("$schema")]
definitions: Optional[Dict[str, JsonSchemaAny]]
definitions: dict[str, JsonSchemaAny] | None
def integer_range_to_type(min_value: float, max_value: float) -> type:
def integer_range_to_type(min_value: float, max_value: float) -> Any:
if min_value >= -(2**15) and max_value < 2**15:
return int16
elif min_value >= -(2**31) and max_value < 2**31:
@ -173,11 +172,11 @@ def enum_safe_name(name: str) -> str:
def enum_values_to_type(
module: types.ModuleType,
name: str,
values: Dict[str, Any],
title: Optional[str] = None,
description: Optional[str] = None,
) -> Type[enum.Enum]:
enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore
values: dict[str, Any],
title: str | None = None,
description: str | None = None,
) -> type[enum.Enum]:
enum_class: type[enum.Enum] = enum.Enum(name, values) # type: ignore
# assign the newly created type to the same module where the defining class is
enum_class.__module__ = module.__name__
@ -330,7 +329,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
type_def = node_to_typedef(module, context, node.items)
if type_def.default is not dataclasses.MISSING:
raise TypeError("disallowed: `default` for array element type")
list_type = List[(type_def.type,)] # type: ignore
list_type = list[(type_def.type,)] # type: ignore
return TypeDef(list_type, dataclasses.MISSING)
elif isinstance(node, JsonSchemaObject):
@ -344,8 +343,8 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
class_name = context
fields: List[Tuple[str, Any, dataclasses.Field]] = []
params: Dict[str, DocstringParam] = {}
fields: list[tuple[str, Any, dataclasses.Field]] = []
params: dict[str, DocstringParam] = {}
for prop_name, prop_node in node.properties.items():
type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node)
if prop_name in required:
@ -358,10 +357,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
params[prop_name] = DocstringParam(prop_name, prop_desc)
fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING)
if sys.version_info >= (3, 12):
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
else:
class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__})
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
class_type.__doc__ = str(
Docstring(
short_description=node.title,
@ -388,7 +384,7 @@ class SchemaFlatteningOptions:
recursive: bool = False
def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema:
def flatten_schema(schema: Schema, *, options: SchemaFlatteningOptions | None = None) -> Schema:
top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema))
flattener = SchemaFlattener(options)
obj = flattener.flatten(top_node)
@ -398,7 +394,7 @@ def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions]
class SchemaFlattener:
options: SchemaFlatteningOptions
def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None:
def __init__(self, options: SchemaFlatteningOptions | None = None) -> None:
self.options = options or SchemaFlatteningOptions()
def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject:
@ -406,10 +402,10 @@ class SchemaFlattener:
return source_node
source_props = source_node.properties or {}
target_props: Dict[str, JsonSchemaAny] = {}
target_props: dict[str, JsonSchemaAny] = {}
source_reqs = source_node.required or []
target_reqs: List[str] = []
target_reqs: list[str] = []
for name, prop in source_props.items():
if not isinstance(prop, JsonSchemaObject):

View file

@ -10,7 +10,7 @@ Type-safe data interchange for Python data classes.
:see: https://github.com/hunyadi/strong_typing
"""
from typing import Dict, List, Union
from typing import Union
class JsonObject:
@ -28,8 +28,8 @@ JsonType = Union[
int,
float,
str,
Dict[str, "JsonType"],
List["JsonType"],
dict[str, "JsonType"],
list["JsonType"],
]
# a JSON type that cannot contain `null` values
@ -38,9 +38,9 @@ StrictJsonType = Union[
int,
float,
str,
Dict[str, "StrictJsonType"],
List["StrictJsonType"],
dict[str, "StrictJsonType"],
list["StrictJsonType"],
]
# a meta-type that captures the object type in a JSON schema
Schema = Dict[str, JsonType]
Schema = dict[str, JsonType]
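`JsonType` is a recursive alias whose containers refer back to it by forward-reference strings; a hedged sketch of a structural check over the same shapes (the `None` and `bool` members are assumed from the elided lines of the hunk):

```python
from typing import Union

# mirrors the alias above: containers reference it by forward string
JsonType = Union[None, bool, int, float, str, dict[str, "JsonType"], list["JsonType"]]

def is_json_value(value: object) -> bool:
    # recursive structural check over the shapes the alias enumerates
    if value is None or isinstance(value, (bool, int, float, str)):
        return True
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_json_value(v) for k, v in value.items())
    if isinstance(value, list):
        return all(is_json_value(v) for v in value)
    return False
```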

View file

@ -20,19 +20,14 @@ import ipaddress
import sys
import typing
import uuid
from collections.abc import Callable
from types import ModuleType
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
@ -70,7 +65,7 @@ V = TypeVar("V")
class Deserializer(abc.ABC, Generic[T]):
"Parses a JSON value into a Python type."
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
"""
Creates auxiliary parsers that this parser is depending on.
@ -203,19 +198,19 @@ class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]):
return ipaddress.IPv6Address(data)
class ListDeserializer(Deserializer[List[T]]):
class ListDeserializer(Deserializer[list[T]]):
"Recursively de-serializes a JSON array into a Python `list`."
item_type: Type[T]
item_type: type[T]
item_parser: Deserializer
def __init__(self, item_type: Type[T]) -> None:
def __init__(self, item_type: type[T]) -> None:
self.item_type = item_type
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.item_parser = _get_deserializer(self.item_type, context)
def parse(self, data: JsonType) -> List[T]:
def parse(self, data: JsonType) -> list[T]:
if not isinstance(data, list):
type_name = python_type_to_str(self.item_type)
raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}")
@ -223,19 +218,19 @@ class ListDeserializer(Deserializer[List[T]]):
return [self.item_parser.parse(item) for item in data]
class DictDeserializer(Deserializer[Dict[K, V]]):
class DictDeserializer(Deserializer[dict[K, V]]):
"Recursively de-serializes a JSON object into a Python `dict`."
key_type: Type[K]
value_type: Type[V]
key_type: type[K]
value_type: type[V]
value_parser: Deserializer[V]
def __init__(self, key_type: Type[K], value_type: Type[V]) -> None:
def __init__(self, key_type: type[K], value_type: type[V]) -> None:
self.key_type = key_type
self.value_type = value_type
self._check_key_type()
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.value_parser = _get_deserializer(self.value_type, context)
def _check_key_type(self) -> None:
@ -264,48 +259,48 @@ class DictDeserializer(Deserializer[Dict[K, V]]):
value_type_name = python_type_to_str(self.value_type)
return f"Dict[{key_type_name}, {value_type_name}]"
def parse(self, data: JsonType) -> Dict[K, V]:
def parse(self, data: JsonType) -> dict[K, V]:
if not isinstance(data, dict):
raise JsonTypeError(
f"`type `{self.container_type}` expects JSON `object` data but instead received: {data}"
)
return dict(
(self.key_type(key), self.value_parser.parse(value)) # type: ignore[call-arg]
return {
self.key_type(key): self.value_parser.parse(value) # type: ignore[call-arg]
for key, value in data.items()
)
}
class SetDeserializer(Deserializer[Set[T]]):
class SetDeserializer(Deserializer[set[T]]):
"Recursively de-serializes a JSON list into a Python `set`."
member_type: Type[T]
member_type: type[T]
member_parser: Deserializer
def __init__(self, member_type: Type[T]) -> None:
def __init__(self, member_type: type[T]) -> None:
self.member_type = member_type
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parser = _get_deserializer(self.member_type, context)
def parse(self, data: JsonType) -> Set[T]:
def parse(self, data: JsonType) -> set[T]:
if not isinstance(data, list):
type_name = python_type_to_str(self.member_type)
raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}")
return set(self.member_parser.parse(item) for item in data)
return {self.member_parser.parse(item) for item in data}
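The `set(...)`-to-comprehension changes in this file are behavior-preserving; a quick check of the equivalence:

```python
items = ["2", "3", "3", "5"]

# generator passed to set(), as before the change
before = set(int(item) for item in items)

# set comprehension, as after the change
after = {int(item) for item in items}
```

Both forms deduplicate identically; the comprehension just avoids the intermediate generator-to-constructor call.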
class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
class TupleDeserializer(Deserializer[tuple[Any, ...]]):
"Recursively de-serializes a JSON list into a Python `tuple`."
item_types: Tuple[Type[Any], ...]
item_parsers: Tuple[Deserializer[Any], ...]
item_types: tuple[type[Any], ...]
item_parsers: tuple[Deserializer[Any], ...]
def __init__(self, item_types: Tuple[Type[Any], ...]) -> None:
def __init__(self, item_types: tuple[type[Any], ...]) -> None:
self.item_types = item_types
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types)
@property
@ -313,7 +308,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types)
return f"Tuple[{type_names}]"
def parse(self, data: JsonType) -> Tuple[Any, ...]:
def parse(self, data: JsonType) -> tuple[Any, ...]:
if not isinstance(data, list) or len(data) != len(self.item_parsers):
if not isinstance(data, list):
raise JsonTypeError(
@ -331,13 +326,13 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
class UnionDeserializer(Deserializer):
"De-serializes a JSON value (of any type) into a Python union type."
member_types: Tuple[type, ...]
member_parsers: Tuple[Deserializer, ...]
member_types: tuple[type, ...]
member_parsers: tuple[Deserializer, ...]
def __init__(self, member_types: Tuple[type, ...]) -> None:
def __init__(self, member_types: tuple[type, ...]) -> None:
self.member_types = member_types
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types)
def parse(self, data: JsonType) -> Any:
@ -354,15 +349,15 @@ class UnionDeserializer(Deserializer):
raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}")
def get_literal_properties(typ: type) -> Set[str]:
def get_literal_properties(typ: type) -> set[str]:
"Returns the names of all properties in a class that are of a literal type."
return set(
return {
property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type)
)
}
def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
def get_discriminating_properties(types: tuple[type, ...]) -> set[str]:
"Returns a set of properties with literal type that are common across all specified classes."
if not types or not all(isinstance(typ, type) for typ in types):
@ -378,15 +373,15 @@ def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
class TaggedUnionDeserializer(Deserializer):
"De-serializes a JSON value with one or more disambiguating properties into a Python union type."
member_types: Tuple[type, ...]
disambiguating_properties: Set[str]
member_parsers: Dict[Tuple[str, Any], Deserializer]
member_types: tuple[type, ...]
disambiguating_properties: set[str]
member_parsers: dict[tuple[str, Any], Deserializer]
def __init__(self, member_types: Tuple[type, ...]) -> None:
def __init__(self, member_types: tuple[type, ...]) -> None:
self.member_types = member_types
self.disambiguating_properties = get_discriminating_properties(member_types)
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
self.member_parsers = {}
for member_type in self.member_types:
for property_name in self.disambiguating_properties:
@ -435,13 +430,13 @@ class TaggedUnionDeserializer(Deserializer):
class LiteralDeserializer(Deserializer):
"De-serializes a JSON value into a Python literal type."
values: Tuple[Any, ...]
values: tuple[Any, ...]
parser: Deserializer
def __init__(self, values: Tuple[Any, ...]) -> None:
def __init__(self, values: tuple[Any, ...]) -> None:
self.values = values
def build(self, context: Optional[ModuleType]) -> None:
def build(self, context: ModuleType | None) -> None:
literal_type_tuple = tuple(type(value) for value in self.values)
literal_type_set = set(literal_type_tuple)
if len(literal_type_set) != 1:
@ -464,9 +459,9 @@ class LiteralDeserializer(Deserializer):
class EnumDeserializer(Deserializer[E]):
"Returns an enumeration instance based on the enumeration value read from a JSON value."
enum_type: Type[E]
enum_type: type[E]
def __init__(self, enum_type: Type[E]) -> None:
def __init__(self, enum_type: type[E]) -> None:
self.enum_type = enum_type
def parse(self, data: JsonType) -> E:
@ -504,13 +499,13 @@ class FieldDeserializer(abc.ABC, Generic[T, R]):
self.parser = parser
@abc.abstractmethod
def parse_field(self, data: Dict[str, JsonType]) -> R: ...
def parse_field(self, data: dict[str, JsonType]) -> R: ...
class RequiredFieldDeserializer(FieldDeserializer[T, T]):
"Deserializes a JSON property into a mandatory Python object field."
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
if self.property_name not in data:
raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}")
@ -520,7 +515,7 @@ class RequiredFieldDeserializer(FieldDeserializer[T, T]):
class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]):
"Deserializes a JSON property into an optional Python object field with a default value of `None`."
def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]:
def parse_field(self, data: dict[str, JsonType]) -> T | None:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -543,7 +538,7 @@ class DefaultFieldDeserializer(FieldDeserializer[T, T]):
super().__init__(property_name, field_name, parser)
self.default_value = default_value
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -566,7 +561,7 @@ class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]):
super().__init__(property_name, field_name, parser)
self.default_factory = default_factory
def parse_field(self, data: Dict[str, JsonType]) -> T:
def parse_field(self, data: dict[str, JsonType]) -> T:
value = data.get(self.property_name)
if value is not None:
return self.parser.parse(value)
@ -578,22 +573,22 @@ class ClassDeserializer(Deserializer[T]):
"Base class for de-serializing class-like types such as data classes, named tuples and regular classes."
class_type: type
property_parsers: List[FieldDeserializer]
property_fields: Set[str]
property_parsers: list[FieldDeserializer]
property_fields: set[str]
def __init__(self, class_type: Type[T]) -> None:
def __init__(self, class_type: type[T]) -> None:
self.class_type = class_type
def assign(self, property_parsers: List[FieldDeserializer]) -> None:
def assign(self, property_parsers: list[FieldDeserializer]) -> None:
self.property_parsers = property_parsers
self.property_fields = set(property_parser.property_name for property_parser in property_parsers)
self.property_fields = {property_parser.property_name for property_parser in property_parsers}
def parse(self, data: JsonType) -> T:
if not isinstance(data, dict):
type_name = python_type_to_str(self.class_type)
raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}")
object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data)
object_data: dict[str, JsonType] = typing.cast(dict[str, JsonType], data)
field_values = {}
for property_parser in self.property_parsers:
@ -619,8 +614,8 @@ class ClassDeserializer(Deserializer[T]):
class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
"De-serializes a named tuple from a JSON `object`."
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = [
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = [
RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context))
for field_name, field_type in get_resolved_hints(self.class_type).items()
]
@ -634,13 +629,13 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
class DataclassDeserializer(ClassDeserializer[T]):
"De-serializes a data class from a JSON `object`."
def __init__(self, class_type: Type[T]) -> None:
def __init__(self, class_type: type[T]) -> None:
if not dataclasses.is_dataclass(class_type):
raise TypeError("expected: data-class type")
super().__init__(class_type) # type: ignore[arg-type]
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = []
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = []
resolved_hints = get_resolved_hints(self.class_type)
for field in dataclasses.fields(self.class_type):
field_type = resolved_hints[field.name]
@ -651,7 +646,7 @@ class DataclassDeserializer(ClassDeserializer[T]):
has_default_factory = field.default_factory is not dataclasses.MISSING
if is_optional:
required_type: Type[T] = unwrap_optional_type(field_type)
required_type: type[T] = unwrap_optional_type(field_type)
else:
required_type = field_type
@ -691,15 +686,15 @@ class FrozenDataclassDeserializer(DataclassDeserializer[T]):
class TypedClassDeserializer(ClassDeserializer[T]):
"De-serializes a class with type annotations from a JSON `object` by iterating over class properties."
def build(self, context: Optional[ModuleType]) -> None:
property_parsers: List[FieldDeserializer] = []
def build(self, context: ModuleType | None) -> None:
property_parsers: list[FieldDeserializer] = []
for field_name, field_type in get_resolved_hints(self.class_type).items():
property_name = python_field_to_json_property(field_name, field_type)
is_optional = is_type_optional(field_type)
if is_optional:
required_type: Type[T] = unwrap_optional_type(field_type)
required_type: type[T] = unwrap_optional_type(field_type)
else:
required_type = field_type
@ -715,7 +710,7 @@ class TypedClassDeserializer(ClassDeserializer[T]):
super().assign(property_parsers)
def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer:
def create_deserializer(typ: TypeLike, context: ModuleType | None = None) -> Deserializer:
"""
Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string.
@ -741,15 +736,15 @@ def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) ->
return _get_deserializer(typ, context)
_CACHE: Dict[Tuple[str, str], Deserializer] = {}
_CACHE: dict[tuple[str, str], Deserializer] = {}
def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer:
def _get_deserializer(typ: TypeLike, context: ModuleType | None) -> Deserializer:
"Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string."
cache_key = None
if isinstance(typ, (str, typing.ForwardRef)):
if isinstance(typ, str | typing.ForwardRef):
if context is None:
raise TypeError(f"missing context for evaluating type: {typ}")

View file

@ -15,17 +15,12 @@ import collections.abc
import dataclasses
import inspect
import re
import sys
import types
import typing
from collections.abc import Callable
from dataclasses import dataclass
from io import StringIO
from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
from typing import Any, Protocol, TypeGuard, TypeVar
from .inspection import (
DataclassInstance,
@ -110,14 +105,14 @@ class Docstring:
:param returns: The returns declaration extracted from a docstring.
"""
short_description: Optional[str] = None
long_description: Optional[str] = None
params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
returns: Optional[DocstringReturns] = None
raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
short_description: str | None = None
long_description: str | None = None
params: dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
returns: DocstringReturns | None = None
raises: dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
@property
def full_description(self) -> Optional[str]:
def full_description(self) -> str | None:
if self.short_description and self.long_description:
return f"{self.short_description}\n\n{self.long_description}"
elif self.short_description:
@ -158,18 +153,18 @@ class Docstring:
return s
def is_exception(member: object) -> TypeGuard[Type[BaseException]]:
def is_exception(member: object) -> TypeGuard[type[BaseException]]:
return isinstance(member, type) and issubclass(member, BaseException)
def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]:
def get_exceptions(module: types.ModuleType) -> dict[str, type[BaseException]]:
"Returns all exception classes declared in a module."
return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)}
return dict(inspect.getmembers(module, is_exception))
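The simplified `get_exceptions` works because `inspect.getmembers` already yields `(name, value)` pairs, so `dict()` builds the name-to-class map directly; for example, against `builtins`:

```python
import builtins
import inspect

def is_exception(member: object) -> bool:
    # same predicate as above: exception classes only
    return isinstance(member, type) and issubclass(member, BaseException)

# getmembers(module, predicate) yields filtered (name, value) pairs
exceptions = dict(inspect.getmembers(builtins, is_exception))
```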
class SupportsDoc(Protocol):
__doc__: Optional[str]
__doc__: str | None
def _maybe_unwrap_async_iterator(t):
@ -213,7 +208,7 @@ def parse_type(typ: SupportsDoc) -> Docstring:
# assign exception types
defining_module = inspect.getmodule(typ)
if defining_module:
context: Dict[str, type] = {}
context: dict[str, type] = {}
context.update(get_exceptions(builtins))
context.update(get_exceptions(defining_module))
for exc_name, exc in docstring.raises.items():
@ -262,8 +257,8 @@ def parse_text(text: str) -> Docstring:
else:
long_description = None
params: Dict[str, DocstringParam] = {}
raises: Dict[str, DocstringRaises] = {}
params: dict[str, DocstringParam] = {}
raises: dict[str, DocstringRaises] = {}
returns = None
for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE):
chunk = match.group(0)
@ -325,7 +320,7 @@ def has_docstring(typ: SupportsDoc) -> bool:
return bool(typ.__doc__)
def get_docstring(typ: SupportsDoc) -> Optional[str]:
def get_docstring(typ: SupportsDoc) -> str | None:
if typ.__doc__ is None:
return None
@ -348,7 +343,7 @@ def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False
check_function_docstring(typ, docstring, strict)
def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
def check_dataclass_docstring(typ: type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
"""
Verifies the doc-string of a data-class type.

View file

@ -22,34 +22,19 @@ import sys
import types
import typing
import uuid
from collections.abc import Callable, Iterable
from typing import (
Annotated,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Protocol,
Set,
Tuple,
Type,
TypeGuard,
TypeVar,
Union,
runtime_checkable,
)
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
S = TypeVar("S")
T = TypeVar("T")
K = TypeVar("K")
@ -80,28 +65,20 @@ def _is_type_like(data_type: object) -> bool:
return False
if sys.version_info >= (3, 9):
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
def is_type_like(
data_type: object,
) -> TypeGuard[TypeLike]:
"""
Checks if the object is a type or type-like object (e.g. generic type).
:param data_type: The object to validate.
:returns: True if the object is a type or type-like object.
"""
def is_type_like(
data_type: object,
) -> TypeGuard[TypeLike]:
"""
Checks if the object is a type or type-like object (e.g. generic type).
return _is_type_like(data_type)
:param data_type: The object to validate.
:returns: True if the object is a type or type-like object.
"""
else:
TypeLike = object
def is_type_like(
data_type: object,
) -> bool:
return _is_type_like(data_type)
return _is_type_like(data_type)
def evaluate_member_type(typ: Any, cls: type) -> Any:
@ -129,20 +106,17 @@ def evaluate_type(typ: Any, module: types.ModuleType) -> Any:
# evaluate data-class field whose type annotation is a string
return eval(typ, module.__dict__, locals())
if isinstance(typ, typing.ForwardRef):
if sys.version_info >= (3, 9):
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
else:
return typ._evaluate(module.__dict__, locals())
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
else:
return typ
@runtime_checkable
class DataclassInstance(Protocol):
__dataclass_fields__: typing.ClassVar[Dict[str, dataclasses.Field]]
__dataclass_fields__: typing.ClassVar[dict[str, dataclasses.Field]]
def is_dataclass_type(typ: Any) -> TypeGuard[Type[DataclassInstance]]:
def is_dataclass_type(typ: Any) -> TypeGuard[type[DataclassInstance]]:
"True if the argument corresponds to a data class type (but not an instance)."
typ = unwrap_annotated_type(typ)
@ -167,14 +141,14 @@ class DataclassField:
self.default = default
def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]:
def dataclass_fields(cls: type[DataclassInstance]) -> Iterable[DataclassField]:
"Generates the fields of a data-class resolving forward references."
for field in dataclasses.fields(cls):
yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default)
def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField:
def dataclass_field_by_name(cls: type[DataclassInstance], name: str) -> DataclassField:
"Looks up a field in a data-class by its field name."
for field in dataclasses.fields(cls):
@ -190,7 +164,7 @@ def is_named_tuple_instance(obj: Any) -> TypeGuard[NamedTuple]:
return is_named_tuple_type(type(obj))
def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
def is_named_tuple_type(typ: Any) -> TypeGuard[type[NamedTuple]]:
"""
True if the argument corresponds to a named tuple type.
@ -217,26 +191,14 @@ def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
return all(isinstance(n, str) for n in f)
-if sys.version_info >= (3, 11):
-
-    def is_type_enum(typ: object) -> TypeGuard[type[enum.Enum]]:
-        "True if the specified type is an enumeration type."
-        typ = unwrap_annotated_type(typ)
-        return isinstance(typ, enum.EnumType)
-
-else:
-
-    def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]:
-        "True if the specified type is an enumeration type."
-        typ = unwrap_annotated_type(typ)
-        # use an explicit isinstance(..., type) check to filter out special forms like generics
-        return isinstance(typ, type) and issubclass(typ, enum.Enum)
+def is_type_enum(typ: object) -> TypeGuard[type[enum.Enum]]:
+    "True if the specified type is an enumeration type."
+    typ = unwrap_annotated_type(typ)
+    return isinstance(typ, enum.EnumType)
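The retained branch relies on `enum.EnumType`, which Python 3.11 introduced as the preferred name for the `enum.EnumMeta` metaclass. A minimal sketch of why an `isinstance` check against the metaclass identifies enumeration types (using `EnumMeta` as a fallback so the snippet also runs on older interpreters; `Color` is an illustrative name):

```python
import enum

# enum.EnumType (3.11+) is an alias of the enum.EnumMeta metaclass; every
# Enum subclass is an *instance* of it, so isinstance() detects enum types.
EnumType = getattr(enum, "EnumType", enum.EnumMeta)

class Color(enum.Enum):
    RED = 1

print(isinstance(Color, EnumType))      # True: the class itself
print(isinstance(Color.RED, EnumType))  # False: a member, not a type
print(isinstance(int, EnumType))        # False: not an enumeration
```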
def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
def enum_value_types(enum_type: type[enum.Enum]) -> list[type]:
"""
Returns all unique value types of the `enum.Enum` type in definition order.
"""
@ -246,8 +208,8 @@ def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
def extend_enum(
source: Type[enum.Enum],
) -> Callable[[Type[enum.Enum]], Type[enum.Enum]]:
source: type[enum.Enum],
) -> Callable[[type[enum.Enum]], type[enum.Enum]]:
"""
Creates a new enumeration type extending the set of values in an existing type.
@ -255,13 +217,13 @@ def extend_enum(
:returns: A new enumeration type with the extended set of values.
"""
def wrap(extend: Type[enum.Enum]) -> Type[enum.Enum]:
def wrap(extend: type[enum.Enum]) -> type[enum.Enum]:
# create new enumeration type combining the values from both types
values: Dict[str, Any] = {}
values: dict[str, Any] = {}
values.update((e.name, e.value) for e in source)
values.update((e.name, e.value) for e in extend)
# mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
enum_class: type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
# assign the newly created type to the same module where the extending class is defined
enum_class.__module__ = extend.__module__
@ -273,22 +235,13 @@ def extend_enum(
return wrap
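`extend_enum` leans on the functional `enum.Enum(name, values)` API, which accepts a mapping of member names to values. A minimal sketch of the merge it performs (`Base`, `Extra`, and `Merged` are illustrative names):

```python
import enum

class Base(enum.Enum):
    RED = 1

class Extra(enum.Enum):
    GREEN = 2

# Merge members from both enums into one mapping, then build a new type
# with the functional Enum API, mirroring what extend_enum does.
values = {e.name: e.value for e in Base}
values.update((e.name, e.value) for e in Extra)
Merged = enum.Enum("Merged", values)

print([m.name for m in Merged])  # ['RED', 'GREEN']
print(Merged.GREEN.value)        # 2
```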
-if sys.version_info >= (3, 10):
-
-    def _is_union_like(typ: object) -> bool:
-        "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
-        return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
-
-else:
-
-    def _is_union_like(typ: object) -> bool:
-        "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
-        return typing.get_origin(typ) is Union
+def _is_union_like(typ: object) -> bool:
+    "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
+    return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]:
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[type[Any | None]]:
"""
True if the type annotation corresponds to an optional type (e.g. `Optional[T]` or `Union[T1,T2,None]`).
@ -309,7 +262,7 @@ def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Option
return False
def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
def unwrap_optional_type(typ: type[T | None]) -> type[T]:
"""
Extracts the inner type of an optional type.
@ -320,7 +273,7 @@ def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_optional_type, typ)
def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
def _unwrap_optional_type(typ: type[T | None]) -> type[T]:
"Extracts the type qualified as optional (e.g. returns `T` for `Optional[T]`)."
# Optional[T] is represented internally as Union[T, None]
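As the comment notes, `Optional[T]` is just `Union[T, None]` at runtime, so unwrapping amounts to filtering `NoneType` out of the union arguments. For instance:

```python
import typing

# Optional[int] and int | None are both Union[int, None] under the hood.
args = typing.get_args(typing.Optional[int])
print(args)  # (<class 'int'>, <class 'NoneType'>)

# Unwrapping drops NoneType and keeps the remaining member type(s).
inner = tuple(a for a in args if a is not type(None))
print(inner)  # (<class 'int'>,)
```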
@ -342,7 +295,7 @@ def is_type_union(typ: object) -> bool:
return False
def unwrap_union_types(typ: object) -> Tuple[object, ...]:
def unwrap_union_types(typ: object) -> tuple[object, ...]:
"""
Extracts the inner types of a union type.
@ -354,7 +307,7 @@ def unwrap_union_types(typ: object) -> Tuple[object, ...]:
return _unwrap_union_types(typ)
def _unwrap_union_types(typ: object) -> Tuple[object, ...]:
def _unwrap_union_types(typ: object) -> tuple[object, ...]:
"Extracts the types in a union (e.g. returns a tuple of types `T1` and `T2` for `Union[T1, T2]`)."
if not _is_union_like(typ):
@ -385,7 +338,7 @@ def unwrap_literal_value(typ: object) -> Any:
return args[0]
def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
def unwrap_literal_values(typ: object) -> tuple[Any, ...]:
"""
Extracts the constant values captured by a literal type.
@ -397,7 +350,7 @@ def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
return typing.get_args(typ)
def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
def unwrap_literal_types(typ: object) -> tuple[type, ...]:
"""
Extracts the types of the constant values captured by a literal type.
@ -408,14 +361,14 @@ def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
return tuple(type(t) for t in unwrap_literal_values(typ))
def is_generic_list(typ: object) -> TypeGuard[Type[list]]:
def is_generic_list(typ: object) -> TypeGuard[type[list]]:
"True if the specified type is a generic list, i.e. `List[T]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is list
def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
def unwrap_generic_list(typ: type[list[T]]) -> type[T]:
"""
Extracts the item type of a list type.
@ -426,21 +379,21 @@ def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_generic_list, typ)
def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
"Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
(list_type,) = typing.get_args(typ) # unpack single tuple element
return list_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
"True if the specified type is a generic set, i.e. `Set[T]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is set
def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
def unwrap_generic_set(typ: type[set[T]]) -> type[T]:
"""
Extracts the item type of a set type.
@ -451,21 +404,21 @@ def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
return rewrap_annotated_type(_unwrap_generic_set, typ)
def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
def _unwrap_generic_set(typ: type[set[T]]) -> type[T]:
"Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
(set_type,) = typing.get_args(typ) # unpack single tuple element
return set_type # type: ignore[no-any-return]
def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
def is_generic_dict(typ: object) -> TypeGuard[type[dict]]:
"True if the specified type is a generic dictionary, i.e. `Dict[KeyType, ValueType]`."
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is dict
def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
def unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
"""
Extracts the key and value types of a dictionary type as a tuple.
@ -476,7 +429,7 @@ def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
return _unwrap_generic_dict(unwrap_annotated_type(typ))
def _unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
def _unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
"Extracts the key and value types of a dict type (e.g. returns (`K`, `V`) for `Dict[K, V]`)."
key_type, value_type = typing.get_args(typ)
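These unwrap helpers are thin wrappers over `typing.get_origin` and `typing.get_args`, which behave identically for builtin generics (`list[int]`, `dict[str, int]`) and the legacy `typing` aliases being removed in this change:

```python
import typing

print(typing.get_origin(list[int]) is list)         # True
print(typing.get_args(dict[str, int]))              # (<class 'str'>, <class 'int'>)
# Legacy aliases normalize to the same builtin origins:
print(typing.get_origin(typing.List[int]) is list)  # True
```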
@ -489,7 +442,7 @@ def is_type_annotated(typ: TypeLike) -> bool:
return getattr(typ, "__metadata__", None) is not None
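`__metadata__` is how `typing.Annotated` exposes its extra arguments; a type is "annotated" exactly when that attribute is present. A quick sketch (`Milliseconds` is an illustrative alias):

```python
from typing import Annotated, get_args

Milliseconds = Annotated[int, "unit=ms"]

# Annotated[T, ...] stores the annotations in __metadata__ and the
# underlying type as the first argument of get_args().
print(Milliseconds.__metadata__)           # ('unit=ms',)
print(get_args(Milliseconds)[0] is int)    # True
print(getattr(int, "__metadata__", None))  # None: plain types are unannotated
```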
def get_annotation(data_type: TypeLike, annotation_type: Type[T]) -> Optional[T]:
def get_annotation(data_type: TypeLike, annotation_type: type[T]) -> T | None:
"""
Returns the first annotation on a data type that matches the expected annotation type.
@ -518,7 +471,7 @@ def unwrap_annotated_type(typ: T) -> T:
return typ
def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]:
def rewrap_annotated_type(transform: Callable[[type[S]], type[T]], typ: type[S]) -> type[T]:
"""
Un-boxes, transforms and re-boxes an optionally annotated type.
@ -542,7 +495,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
return transformed_type
def get_module_classes(module: types.ModuleType) -> List[type]:
def get_module_classes(module: types.ModuleType) -> list[type]:
"Returns all classes declared directly in a module."
def is_class_member(member: object) -> TypeGuard[type]:
@ -551,18 +504,11 @@ def get_module_classes(module: types.ModuleType) -> List[type]:
return [class_type for _, class_type in inspect.getmembers(module, is_class_member)]
-if sys.version_info >= (3, 9):
-
-    def get_resolved_hints(typ: type) -> Dict[str, type]:
-        return typing.get_type_hints(typ, include_extras=True)
-
-else:
-
-    def get_resolved_hints(typ: type) -> Dict[str, type]:
-        return typing.get_type_hints(typ)
+def get_resolved_hints(typ: type) -> dict[str, type]:
+    return typing.get_type_hints(typ, include_extras=True)
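With the version gate gone, `get_resolved_hints` always passes `include_extras=True`, which keeps `Annotated` wrappers (and the metadata this library stores in them) instead of stripping them. A sketch of the difference (`Request` is an illustrative class):

```python
from typing import Annotated, get_type_hints

class Request:
    timeout: Annotated[int, "unit=s"]

plain = get_type_hints(Request)                        # strips Annotated
extras = get_type_hints(Request, include_extras=True)  # preserves it

print(plain["timeout"] is int)         # True
print(extras["timeout"].__metadata__)  # ('unit=s',)
```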
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
def get_class_properties(typ: type) -> Iterable[tuple[str, type | str]]:
"Returns all properties of a class."
if is_dataclass_type(typ):
@ -572,7 +518,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
return resolved_hints.items()
def get_class_property(typ: type, name: str) -> Optional[type | str]:
def get_class_property(typ: type, name: str) -> type | str | None:
"Looks up the annotated type of a property in a class by its property name."
for property_name, property_type in get_class_properties(typ):
@ -586,7 +532,7 @@ class _ROOT:
pass
def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]:
def get_referenced_types(typ: TypeLike, module: types.ModuleType | None = None) -> set[type]:
"""
Extracts types directly or indirectly referenced by this type.
@ -610,10 +556,10 @@ class TypeCollector:
:param graph: The type dependency graph, linking types to types they depend on.
"""
graph: Dict[type, Set[type]]
graph: dict[type, set[type]]
@property
def references(self) -> Set[type]:
def references(self) -> set[type]:
"Types collected by the type collector."
dependencies = set()
@ -638,8 +584,8 @@ class TypeCollector:
def run(
self,
typ: TypeLike,
cls: Type[DataclassInstance],
module: Optional[types.ModuleType],
cls: type[DataclassInstance],
module: types.ModuleType | None,
) -> None:
"""
Extracts types indirectly referenced by this type.
@ -702,26 +648,17 @@ class TypeCollector:
for field in dataclass_fields(typ):
self.run(field.type, typ, context)
else:
for field_name, field_type in get_resolved_hints(typ).items():
for _field_name, field_type in get_resolved_hints(typ).items():
self.run(field_type, typ, context)
return
raise TypeError(f"expected: type-like; got: {typ}")
-if sys.version_info >= (3, 10):
-
-    def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
-        "Extracts the signature of a function."
-        return inspect.signature(fn, eval_str=True)
-
-else:
-
-    def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
-        "Extracts the signature of a function."
-        return inspect.signature(fn)
+def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
+    "Extracts the signature of a function."
+    return inspect.signature(fn, eval_str=True)
def is_reserved_property(name: str) -> bool:
@ -756,51 +693,20 @@ def create_module(name: str) -> types.ModuleType:
return module
-if sys.version_info >= (3, 10):
-
-    def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
-        """
-        Creates a new data-class type dynamically.
-
-        :param class_name: The name of new data-class type.
-        :param fields: A list of fields (and their type) that the new data-class type is expected to have.
-        :returns: The newly created data-class type.
-        """
-
-        # has the `slots` parameter
-        return dataclasses.make_dataclass(class_name, fields, slots=True)
-
-else:
-
-    def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
-        """
-        Creates a new data-class type dynamically.
-
-        :param class_name: The name of new data-class type.
-        :param fields: A list of fields (and their type) that the new data-class type is expected to have.
-        :returns: The newly created data-class type.
-        """
-
-        cls = dataclasses.make_dataclass(class_name, fields)
-
-        cls_dict = dict(cls.__dict__)
-        field_names = tuple(field.name for field in dataclasses.fields(cls))
-
-        cls_dict["__slots__"] = field_names
-
-        for field_name in field_names:
-            cls_dict.pop(field_name, None)
-        cls_dict.pop("__dict__", None)
-
-        qualname = getattr(cls, "__qualname__", None)
-        cls = type(cls)(cls.__name__, (), cls_dict)
-        if qualname is not None:
-            cls.__qualname__ = qualname
-
-        return cls
+def create_data_type(class_name: str, fields: list[tuple[str, type]]) -> type:
+    """
+    Creates a new data-class type dynamically.
+
+    :param class_name: The name of new data-class type.
+    :param fields: A list of fields (and their type) that the new data-class type is expected to have.
+    :returns: The newly created data-class type.
+    """
+    # has the `slots` parameter
+    return dataclasses.make_dataclass(class_name, fields, slots=True)
def create_object(typ: Type[T]) -> T:
def create_object(typ: type[T]) -> T:
"Creates an instance of a type."
if issubclass(typ, Exception):
@ -811,11 +717,7 @@ def create_object(typ: Type[T]) -> T:
return object.__new__(typ)
-if sys.version_info >= (3, 9):
-    TypeOrGeneric = Union[type, types.GenericAlias]
-
-else:
-    TypeOrGeneric = object
+TypeOrGeneric = Union[type, types.GenericAlias]
def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
@ -885,7 +787,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
class RecursiveChecker:
_pred: Optional[Callable[[type, Any], bool]]
_pred: Callable[[type, Any], bool] | None
def __init__(self, pred: Callable[[type, Any], bool]) -> None:
"""
@ -997,9 +899,9 @@ def check_recursive(
obj: object,
/,
*,
pred: Optional[Callable[[type, Any], bool]] = None,
type_pred: Optional[Callable[[type], bool]] = None,
value_pred: Optional[Callable[[Any], bool]] = None,
pred: Callable[[type, Any], bool] | None = None,
type_pred: Callable[[type], bool] | None = None,
value_pred: Callable[[Any], bool] | None = None,
) -> bool:
"""
Checks if a predicate applies to all nested member properties of an object recursively.
@ -1015,7 +917,7 @@ def check_recursive(
if pred is not None:
raise TypeError("filter predicate not permitted when type and value predicates are present")
type_p: Callable[[Type[T]], bool] = type_pred
type_p: Callable[[type[T]], bool] = type_pred
value_p: Callable[[T], bool] = value_pred
pred = lambda typ, obj: not type_p(typ) or value_p(obj) # noqa: E731

View file

@ -11,13 +11,12 @@ Type-safe data interchange for Python data classes.
"""
import keyword
from typing import Optional
from .auxiliary import Alias
from .inspection import get_annotation
def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str:
def python_field_to_json_property(python_id: str, python_type: object | None = None) -> str:
"""
Map a Python field identifier to a JSON property name.

View file

@ -11,7 +11,7 @@ Type-safe data interchange for Python data classes.
"""
import typing
from typing import Any, Literal, Optional, Tuple, Union
from typing import Any, Literal, Union
from .auxiliary import _auxiliary_types
from .inspection import (
@ -39,7 +39,7 @@ class TypeFormatter:
def __init__(self, use_union_operator: bool = False) -> None:
self.use_union_operator = use_union_operator
def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str:
def union_to_str(self, data_type_args: tuple[TypeLike, ...]) -> str:
if self.use_union_operator:
return " | ".join(self.python_type_to_str(t) for t in data_type_args)
else:
@ -100,7 +100,7 @@ class TypeFormatter:
metadata = getattr(data_type, "__metadata__", None)
if metadata is not None:
# type is Annotated[T, ...]
metatuple: Tuple[Any, ...] = metadata
metatuple: tuple[Any, ...] = metadata
arg = typing.get_args(data_type)[0]
# check for auxiliary types with user-defined annotations
@ -110,7 +110,7 @@ class TypeFormatter:
if arg is not auxiliary_arg:
continue
auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None)
auxiliary_metatuple: tuple[Any, ...] | None = getattr(auxiliary_type, "__metadata__", None)
if auxiliary_metatuple is None:
continue

View file

@ -21,24 +21,19 @@ import json
import types
import typing
import uuid
+from collections.abc import Callable
 from copy import deepcopy
 from typing import (
+    Annotated,
     Any,
-    Callable,
     ClassVar,
-    Dict,
-    List,
     Literal,
-    Optional,
-    Tuple,
-    Type,
     TypeVar,
     Union,
     overload,
 )

 import jsonschema
-from typing_extensions import Annotated
from . import docstring
from .auxiliary import (
@ -71,7 +66,7 @@ OBJECT_ENUM_EXPANSION_LIMIT = 4
T = TypeVar("T")
def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]:
def get_class_docstrings(data_type: type) -> tuple[str | None, str | None]:
docstr = docstring.parse_type(data_type)
# check if class has a doc-string other than the auto-generated string assigned by @dataclass
@ -82,8 +77,8 @@ def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]
def get_class_property_docstrings(
data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None
) -> Dict[str, str]:
data_type: type, transform_fun: Callable[[type, str, str], str] | None = None
) -> dict[str, str]:
"""
Extracts the documentation strings associated with the properties of a composite type.
@ -120,7 +115,7 @@ def docstring_to_schema(data_type: type) -> Schema:
return schema
def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
def id_from_ref(data_type: typing.ForwardRef | str | type) -> str:
"Extracts the name of a possibly forward-referenced type."
if isinstance(data_type, typing.ForwardRef):
@ -132,7 +127,7 @@ def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
return data_type.__name__
def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]:
def type_from_ref(data_type: typing.ForwardRef | str | type) -> tuple[str, type]:
"Creates a type from a forward reference."
if isinstance(data_type, typing.ForwardRef):
@ -148,16 +143,16 @@ def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str,
@dataclasses.dataclass
class TypeCatalogEntry:
schema: Optional[Schema]
schema: Schema | None
identifier: str
examples: Optional[JsonType] = None
examples: JsonType | None = None
class TypeCatalog:
"Maintains an association of well-known Python types to their JSON schema."
_by_type: Dict[TypeLike, TypeCatalogEntry]
_by_name: Dict[str, TypeCatalogEntry]
_by_type: dict[TypeLike, TypeCatalogEntry]
_by_name: dict[str, TypeCatalogEntry]
def __init__(self) -> None:
self._by_type = {}
@ -174,9 +169,9 @@ class TypeCatalog:
def add(
self,
data_type: TypeLike,
schema: Optional[Schema],
schema: Schema | None,
identifier: str,
examples: Optional[List[JsonType]] = None,
examples: list[JsonType] | None = None,
) -> None:
if isinstance(data_type, typing.ForwardRef):
raise TypeError("forward references cannot be used to register a type")
@ -202,17 +197,17 @@ class SchemaOptions:
definitions_path: str = "#/definitions/"
use_descriptions: bool = True
use_examples: bool = True
property_description_fun: Optional[Callable[[type, str, str], str]] = None
property_description_fun: Callable[[type, str, str], str] | None = None
class JsonSchemaGenerator:
"Creates a JSON schema with user-defined type definitions."
type_catalog: ClassVar[TypeCatalog] = TypeCatalog()
types_used: Dict[str, TypeLike]
types_used: dict[str, TypeLike]
options: SchemaOptions
def __init__(self, options: Optional[SchemaOptions] = None):
def __init__(self, options: SchemaOptions | None = None):
if options is None:
self.options = SchemaOptions()
else:
@ -244,13 +239,13 @@ class JsonSchemaGenerator:
def _(self, arg: MaxLength) -> Schema:
return {"maxLength": arg.value}
def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema:
def _with_metadata(self, type_schema: Schema, metadata: tuple[Any, ...] | None) -> Schema:
if metadata:
for m in metadata:
type_schema.update(self._metadata_to_schema(m))
return type_schema
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]:
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: dict | None = None) -> Schema | None:
"""
Returns the JSON schema associated with a simple, unrestricted type.
@ -314,7 +309,7 @@ class JsonSchemaGenerator:
self,
data_type: TypeLike,
force_expand: bool = False,
json_schema_extra: Optional[dict] = None,
json_schema_extra: dict | None = None,
) -> Schema:
common_info = {}
if json_schema_extra and "deprecated" in json_schema_extra:
@ -325,7 +320,7 @@ class JsonSchemaGenerator:
self,
data_type: TypeLike,
force_expand: bool = False,
json_schema_extra: Optional[dict] = None,
json_schema_extra: dict | None = None,
) -> Schema:
"""
Returns the JSON schema associated with a type.
@ -381,7 +376,7 @@ class JsonSchemaGenerator:
return {"$ref": f"{self.options.definitions_path}{identifier}"}
if is_type_enum(typ):
enum_type: Type[enum.Enum] = typ
enum_type: type[enum.Enum] = typ
value_types = enum_value_types(enum_type)
if len(value_types) != 1:
raise ValueError(
@ -496,8 +491,8 @@ class JsonSchemaGenerator:
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
properties: Dict[str, Schema] = {}
required: List[str] = []
properties: dict[str, Schema] = {}
required: list[str] = []
for property_name, property_type in get_class_properties(typ):
# rename property if an alias name is specified
alias = get_annotation(property_type, Alias)
@ -530,16 +525,7 @@ class JsonSchemaGenerator:
# check if value can be directly represented in JSON
 if isinstance(
     def_value,
-    (
-        bool,
-        int,
-        float,
-        str,
-        enum.Enum,
-        datetime.datetime,
-        datetime.date,
-        datetime.time,
-    ),
+    bool | int | float | str | enum.Enum | datetime.datetime | datetime.date | datetime.time,
 ):
property_def["default"] = object_to_json(def_value)
@ -587,7 +573,7 @@ class JsonSchemaGenerator:
return type_schema
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]:
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> tuple[Schema, dict[str, Schema]]:
"""
Returns the JSON schema associated with a type and any nested types.
@ -604,7 +590,7 @@ class JsonSchemaGenerator:
try:
type_schema = self.type_to_schema(data_type, force_expand=force_expand)
types_defined: Dict[str, Schema] = {}
types_defined: dict[str, Schema] = {}
while len(self.types_used) > len(types_defined):
# make a snapshot copy; original collection is going to be modified
types_undefined = {
@ -635,7 +621,7 @@ class Validator(enum.Enum):
def classdef_to_schema(
data_type: TypeLike,
options: Optional[SchemaOptions] = None,
options: SchemaOptions | None = None,
validator: Validator = Validator.Latest,
) -> Schema:
"""
@ -689,7 +675,7 @@ def print_schema(data_type: type) -> None:
print(json.dumps(s, indent=4))
def get_schema_identifier(data_type: type) -> Optional[str]:
def get_schema_identifier(data_type: type) -> str | None:
if data_type in JsonSchemaGenerator.type_catalog:
return JsonSchemaGenerator.type_catalog.get(data_type).identifier
else:
@ -698,9 +684,9 @@ def get_schema_identifier(data_type: type) -> Optional[str]:
def register_schema(
data_type: T,
schema: Optional[Schema] = None,
name: Optional[str] = None,
examples: Optional[List[JsonType]] = None,
schema: Schema | None = None,
name: str | None = None,
examples: list[JsonType] | None = None,
) -> T:
"""
Associates a type with a JSON schema definition.
@ -721,22 +707,22 @@ def register_schema(
@overload
def json_schema_type(cls: Type[T], /) -> Type[T]: ...
def json_schema_type(cls: type[T], /) -> type[T]: ...
@overload
def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ...
def json_schema_type(cls: None, *, schema: Schema | None = None) -> Callable[[type[T]], type[T]]: ...
def json_schema_type(
cls: Optional[Type[T]] = None,
cls: type[T] | None = None,
*,
schema: Optional[Schema] = None,
examples: Optional[List[JsonType]] = None,
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
schema: Schema | None = None,
examples: list[JsonType] | None = None,
) -> type[T] | Callable[[type[T]], type[T]]:
"""Decorator to add user-defined schema definition to a class."""
def wrap(cls: Type[T]) -> Type[T]:
def wrap(cls: type[T]) -> type[T]:
return register_schema(cls, schema, examples=examples)
# see if decorator is used as @json_schema_type or @json_schema_type()

View file

@ -14,7 +14,7 @@ import inspect
import json
import sys
from types import ModuleType
from typing import Any, Optional, TextIO, TypeVar
from typing import Any, TextIO, TypeVar
from .core import JsonType
from .deserializer import create_deserializer
@ -42,7 +42,7 @@ def object_to_json(obj: Any) -> JsonType:
return generator.generate(obj)
def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object:
def json_to_object(typ: TypeLike, data: JsonType, *, context: ModuleType | None = None) -> object:
"""
Creates an object from a representation that has been de-serialized from JSON.

View file

@ -20,19 +20,13 @@ import ipaddress
import sys
import typing
import uuid
+from collections.abc import Callable
 from types import FunctionType, MethodType, ModuleType
 from typing import (
     Any,
-    Callable,
-    Dict,
     Generic,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Set,
-    Tuple,
-    Type,
     TypeVar,
     Union,
 )
@ -133,7 +127,7 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
class EnumSerializer(Serializer[enum.Enum]):
def generate(self, obj: enum.Enum) -> Union[int, str]:
def generate(self, obj: enum.Enum) -> int | str:
value = obj.value
if isinstance(value, int):
return value
@ -141,12 +135,12 @@ class EnumSerializer(Serializer[enum.Enum]):
class UntypedListSerializer(Serializer[list]):
def generate(self, obj: list) -> List[JsonType]:
def generate(self, obj: list) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class UntypedDictSerializer(Serializer[dict]):
def generate(self, obj: dict) -> Dict[str, JsonType]:
def generate(self, obj: dict) -> dict[str, JsonType]:
if obj and isinstance(next(iter(obj.keys())), enum.Enum):
iterator = ((key.value, object_to_json(value)) for key, value in obj.items())
else:
@ -155,41 +149,41 @@ class UntypedDictSerializer(Serializer[dict]):
class UntypedSetSerializer(Serializer[set]):
def generate(self, obj: set) -> List[JsonType]:
def generate(self, obj: set) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class UntypedTupleSerializer(Serializer[tuple]):
def generate(self, obj: tuple) -> List[JsonType]:
def generate(self, obj: tuple) -> list[JsonType]:
return [object_to_json(item) for item in obj]
class TypedCollectionSerializer(Serializer, Generic[T]):
generator: Serializer[T]
def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, item_type: type[T], context: ModuleType | None) -> None:
self.generator = _get_serializer(item_type, context)
class TypedListSerializer(TypedCollectionSerializer[T]):
def generate(self, obj: List[T]) -> List[JsonType]:
def generate(self, obj: list[T]) -> list[JsonType]:
return [self.generator.generate(item) for item in obj]
class TypedStringDictSerializer(TypedCollectionSerializer[T]):
def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, value_type: type[T], context: ModuleType | None) -> None:
super().__init__(value_type, context)
def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]:
def generate(self, obj: dict[str, T]) -> dict[str, JsonType]:
return {key: self.generator.generate(value) for key, value in obj.items()}
class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
def __init__(
self,
key_type: Type[enum.Enum],
value_type: Type[T],
context: Optional[ModuleType],
key_type: type[enum.Enum],
value_type: type[T],
context: ModuleType | None,
) -> None:
super().__init__(value_type, context)
@ -203,22 +197,22 @@ class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
if value_type is not str:
raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values")
def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]:
def generate(self, obj: dict[enum.Enum, T]) -> dict[str, JsonType]:
return {key.value: self.generator.generate(value) for key, value in obj.items()}
class TypedSetSerializer(TypedCollectionSerializer[T]):
def generate(self, obj: Set[T]) -> JsonType:
def generate(self, obj: set[T]) -> JsonType:
return [self.generator.generate(item) for item in obj]
class TypedTupleSerializer(Serializer[tuple]):
item_generators: Tuple[Serializer, ...]
item_generators: tuple[Serializer, ...]
def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None:
def __init__(self, item_types: tuple[type, ...], context: ModuleType | None) -> None:
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
def generate(self, obj: tuple) -> List[JsonType]:
def generate(self, obj: tuple) -> list[JsonType]:
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
@ -250,16 +244,16 @@ class FieldSerializer(Generic[T]):
self.property_name = property_name
self.generator = generator
def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None:
def generate_field(self, obj: object, object_dict: dict[str, JsonType]) -> None:
value = getattr(obj, self.field_name)
if value is not None:
object_dict[self.property_name] = self.generator.generate(value)
class TypedClassSerializer(Serializer[T]):
property_generators: List[FieldSerializer]
property_generators: list[FieldSerializer]
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
self.property_generators = [
FieldSerializer(
field_name,
@ -269,8 +263,8 @@ class TypedClassSerializer(Serializer[T]):
for field_name, field_type in get_class_properties(class_type)
]
def generate(self, obj: T) -> Dict[str, JsonType]:
object_dict: Dict[str, JsonType] = {}
def generate(self, obj: T) -> dict[str, JsonType]:
object_dict: dict[str, JsonType] = {}
for property_generator in self.property_generators:
property_generator.generate_field(obj, object_dict)
@ -278,12 +272,12 @@ class TypedClassSerializer(Serializer[T]):
class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]):
def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[NamedTuple], context: ModuleType | None) -> None:
super().__init__(class_type, context)
class DataclassSerializer(TypedClassSerializer[T]):
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
super().__init__(class_type, context)
@ -295,7 +289,7 @@ class UnionSerializer(Serializer):
class LiteralSerializer(Serializer):
generator: Serializer
def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None:
def __init__(self, values: tuple[Any, ...], context: ModuleType | None) -> None:
literal_type_tuple = tuple(type(value) for value in values)
literal_type_set = set(literal_type_tuple)
if len(literal_type_set) != 1:
@@ -312,12 +306,12 @@ class LiteralSerializer(Serializer):
class UntypedNamedTupleSerializer(Serializer):
fields: Dict[str, str]
fields: dict[str, str]
def __init__(self, class_type: Type[NamedTuple]) -> None:
def __init__(self, class_type: type[NamedTuple]) -> None:
# named tuples are also instances of tuple
self.fields = {}
field_names: Tuple[str, ...] = class_type._fields
field_names: tuple[str, ...] = class_type._fields
for field_name in field_names:
self.fields[field_name] = python_field_to_json_property(field_name)
@@ -351,7 +345,7 @@ class UntypedClassSerializer(Serializer):
return object_dict
def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer:
def create_serializer(typ: TypeLike, context: ModuleType | None = None) -> Serializer:
"""
Creates a serializer engine to produce an object that can be directly converted into a JSON string.
@@ -376,8 +370,8 @@ def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Se
return _get_serializer(typ, context)
def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
if isinstance(typ, (str, typing.ForwardRef)):
def _get_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
if isinstance(typ, str | typing.ForwardRef):
if context is None:
raise TypeError(f"missing context for evaluating type: {typ}")
@@ -390,13 +384,13 @@ def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
return _create_serializer(typ, context)
@functools.lru_cache(maxsize=None)
@functools.cache
def _fetch_serializer(typ: type) -> Serializer:
context = sys.modules[typ.__module__]
return _create_serializer(typ, context)
def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
def _create_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
# check for well-known types
if typ is type(None):
return NoneSerializer()


@@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Tuple, Type, TypeVar
from typing import Any, TypeVar
T = TypeVar("T")
class SlotsMeta(type):
def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T:
def __new__(cls: type[T], name: str, bases: tuple[type, ...], ns: dict[str, Any]) -> T:
# caller may have already provided slots, in which case just retain them and keep going
slots: Tuple[str, ...] = ns.get("__slots__", ())
slots: tuple[str, ...] = ns.get("__slots__", ())
# add fields with type annotations to slots
annotations: Dict[str, Any] = ns.get("__annotations__", {})
annotations: dict[str, Any] = ns.get("__annotations__", {})
members = tuple(member for member in annotations.keys() if member not in slots)
# assign slots
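A minimal sketch of the `SlotsMeta` idea, assuming the omitted tail of `__new__` simply merges the annotated members into `__slots__` before creating the class:

```python
# Metaclass that turns every annotated class attribute into a slot, so
# instances get no per-instance __dict__.
class SlotsMeta(type):
    def __new__(cls, name: str, bases: tuple[type, ...], ns: dict) -> type:
        # caller may have already provided slots; retain them
        slots: tuple[str, ...] = ns.get("__slots__", ())
        annotations: dict = ns.get("__annotations__", {})
        members = tuple(m for m in annotations if m not in slots)
        ns["__slots__"] = slots + members
        return super().__new__(cls, name, bases, ns)

class Point(metaclass=SlotsMeta):
    x: int
    y: int

p = Point()
p.x = 1
assert Point.__slots__ == ("x", "y")
assert not hasattr(p, "__dict__")  # slots suppress the instance dict
```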


@@ -10,14 +10,15 @@ Type-safe data interchange for Python data classes.
:see: https://github.com/hunyadi/strong_typing
"""
from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar
from collections.abc import Callable, Iterable
from typing import TypeVar
from .inspection import TypeCollector
T = TypeVar("T")
def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
def topological_sort(graph: dict[T, set[T]]) -> list[T]:
"""
Performs a topological sort of a graph.
@@ -29,9 +30,9 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
"""
# empty list that will contain the sorted nodes (in reverse order)
ordered: List[T] = []
ordered: list[T] = []
seen: Dict[T, bool] = {}
seen: dict[T, bool] = {}
def _visit(n: T) -> None:
status = seen.get(n)
@@ -57,8 +58,8 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
def type_topological_sort(
types: Iterable[type],
dependency_fn: Optional[Callable[[type], Iterable[type]]] = None,
) -> List[type]:
dependency_fn: Callable[[type], Iterable[type]] | None = None,
) -> list[type]:
"""
Performs a topological sort of a list of types.
@@ -78,7 +79,7 @@ def type_topological_sort(
graph = collector.graph
if dependency_fn:
new_types: Set[type] = set()
new_types: set[type] = set()
for source_type, references in graph.items():
dependent_types = dependency_fn(source_type)
references.update(dependent_types)
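The `dict[T, set[T]] -> list[T]` contract can be sketched as a depth-first traversal with a three-state `seen` map (a simplification of the function above, which uses the same in-progress/done bookkeeping):

```python
def topological_sort(graph: dict[str, set[str]]) -> list[str]:
    """Return the nodes so that every node appears after its dependencies."""
    ordered: list[str] = []
    seen: dict[str, bool] = {}  # False = visit in progress, True = done

    def _visit(n: str) -> None:
        status = seen.get(n)
        if status is True:  # already emitted
            return
        if status is False:  # back-edge into an in-progress node
            raise ValueError(f"cycle detected at {n!r}")
        seen[n] = False
        for dep in graph.get(n, ()):  # visit dependencies first
            _visit(dep)
        seen[n] = True
        ordered.append(n)

    for node in graph:
        _visit(node)
    return ordered

# "app" depends on "lib", which depends on "core"
order = topological_sort({"app": {"lib"}, "lib": {"core"}, "core": set()})
assert order.index("core") < order.index("lib") < order.index("app")
```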


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .ollama import get_distribution_template # noqa: F401


@@ -0,0 +1,39 @@
version: 2
distribution_spec:
description: Use (an external) Ollama server for running LLM inference
providers:
inference:
- remote::ollama
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
files:
- inline::localfs
post_training:
- inline::huggingface
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::rag-runtime
- remote::model-context-protocol
- remote::wolfram-alpha
image_type: conda
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]


@@ -0,0 +1,168 @@
# Ollama Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Ollama Server
This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps:
1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/)
2. **Start the Ollama server**:
```bash
ollama serve
```
By default, Ollama serves on `http://127.0.0.1:11434`.
3. **Pull the required models**:
```bash
# Pull the inference model
ollama pull meta-llama/Llama-3.2-3B-Instruct
# Pull the embedding model
ollama pull all-minilm:latest
# (Optional) Pull the safety model for run-with-safety.yaml
ollama pull meta-llama/Llama-Guard-3-1B
```
## Supported Services
### Inference: Ollama
Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable.
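As a quick sanity check (a sketch, not part of the distribution), you can probe the server's `/api/tags` endpoint, which lists the locally pulled models:

```python
import json
import urllib.error
import urllib.request

def ollama_is_reachable(url: str = "http://127.0.0.1:11434", timeout: float = 2.0) -> bool:
    """Return True if an Ollama server answers on its /api/tags endpoint."""
    try:
        with urllib.request.urlopen(f"{url}/api/tags", timeout=timeout) as resp:
            json.load(resp)  # the response is a JSON object with a "models" list
        return True
    except (OSError, ValueError):  # URLError is an OSError; bad JSON is a ValueError
        return False
```

If this returns `False`, verify that `ollama serve` is running and that `OLLAMA_URL` points at it.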
### Vector IO: FAISS
Provides vector storage capabilities using FAISS for embeddings and similarity search operations.
### Safety: Llama Guard (Optional)
When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server.
### Agents: Meta Reference
Provides agent execution capabilities using the meta-reference implementation.
### Post-Training: Hugging Face
Supports model fine-tuning using Hugging Face integration.
### Tool Runtime
Supports various external tools including:
- Brave Search
- Tavily Search
- RAG Runtime
- Model Context Protocol
- Wolfram Alpha
## Running Llama Stack with Ollama
You can run the stack via Conda or venv (building the distribution code yourself), or via Docker, which provides a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via Conda
```bash
llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
llama stack build --template ollama --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Running with Safety
To enable safety checks, use the `run-with-safety.yaml` configuration:
```bash
llama stack run ./run-with-safety.yaml \
--port 8321 \
--env OLLAMA_URL=$OLLAMA_URL \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL
```
## Example Usage
Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
# Run inference
response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.2-3B-Instruct",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.completion_message.content)
```
## Troubleshooting
### Common Issues
1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL.
2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull <model-name>`.
3. **Performance issues**: Consider using a smaller model or adjusting the Ollama server configuration for better performance.
### Logs
Check the Ollama server logs for any issues:
```bash
# Ollama logs are typically available in:
# - macOS: ~/Library/Logs/Ollama/
# - Linux: ~/.ollama/logs/
```


@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::ollama"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"files": ["inline::localfs"],
"post_training": ["inline::huggingface"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
"remote::wolfram-alpha",
],
}
name = "ollama"
inference_provider = Provider(
provider_id="ollama",
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(),
)
vector_io_provider_faiss = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(
f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/distributions/{name}"
),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
),
)
posttraining_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="ollama",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="ollama",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="ollama",
provider_model_id="all-minilm:latest",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use (an external) Ollama server for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
config={},
),
Provider(
provider_id="code-scanner",
provider_type="inline::code-scanner",
config={},
),
],
},
default_models=[
inference_model,
safety_model,
embedding_model,
],
default_shields=[
ShieldInput(
shield_id="${env.SAFETY_MODEL}",
provider_id="llama-guard",
),
ShieldInput(
shield_id="CodeScanner",
provider_id="code-scanner",
),
],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"OLLAMA_URL": (
"http://127.0.0.1:11434",
"URL of the Ollama server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the Ollama server",
),
"SAFETY_MODEL": (
"meta-llama/Llama-Guard-3-1B",
"Safety model loaded into the Ollama server",
),
},
)


@@ -0,0 +1,158 @@
version: 2
image_name: ollama
apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
- provider_id: code-scanner
provider_type: inline::code-scanner
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: llama-guard
- shield_id: CodeScanner
provider_id: code-scanner
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321


@@ -0,0 +1,148 @@
version: 2
image_name: ollama
apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ollama
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
- provider_id: wolfram-alpha
provider_type: remote::wolfram-alpha
config:
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::wolfram_alpha
provider_id: wolfram-alpha
server:
port: 8321


@@ -70,11 +70,44 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
default_value = field.default_factory()
# HACK ALERT:
# If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
# replace it with a generic ~/.llama/ path for documentation
if isinstance(default_value, str) and "/.llama/" in default_value:
if ".llama/" in default_value:
# replace it with a generic XDG-compliant path for documentation
if isinstance(default_value, str):
# Handle legacy .llama/ paths
if "/.llama/" in default_value:
path_part = default_value.split(".llama/")[-1]
default_value = f"~/.llama/{path_part}"
# Use appropriate XDG directory based on path content
if path_part.startswith("runtime/"):
default_value = f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/{path_part}"
else:
default_value = f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/{path_part}"
# Handle XDG state paths (runtime data)
elif "/llama-stack/runtime/" in default_value:
path_part = default_value.split("/llama-stack/runtime/")[-1]
default_value = (
f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/runtime/{path_part}"
)
# Handle XDG data paths
elif "/llama-stack/data/" in default_value or "/llama-stack/checkpoints/" in default_value:
if "/llama-stack/data/" in default_value:
path_part = default_value.split("/llama-stack/data/")[-1]
default_value = (
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/data/{path_part}"
)
else:
path_part = default_value.split("/llama-stack/checkpoints/")[-1]
default_value = (
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/checkpoints/{path_part}"
)
# Handle XDG config paths
elif "/llama-stack/" in default_value and (
"/config/" in default_value or "/distributions/" in default_value
):
if "/config/" in default_value:
path_part = default_value.split("/llama-stack/")[-1]
default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}"
else:
path_part = default_value.split("/llama-stack/")[-1]
default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}"
except Exception:
default_value = ""
elif field.default is None:
@@ -201,7 +234,9 @@ def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
if sample_config_func is not None:
sig = inspect.signature(sample_config_func)
if "__distro_dir__" in sig.parameters:
sample_config = sample_config_func(__distro_dir__="~/.llama/dummy")
sample_config = sample_config_func(
__distro_dir__="${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy"
)
else:
sample_config = sample_config_func()


@@ -0,0 +1,593 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import yaml
from llama_stack.distribution.utils.xdg_utils import (
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
)
class TestXDGEndToEnd(unittest.TestCase):
"""End-to-end tests for XDG compliance workflows."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
self.env_vars = [
"XDG_CONFIG_HOME",
"XDG_DATA_HOME",
"XDG_STATE_HOME",
"XDG_CACHE_HOME",
"LLAMA_STACK_CONFIG_DIR",
"SQLITE_STORE_DIR",
"FILES_STORAGE_DIR",
]
for key in self.env_vars:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def clear_env_vars(self):
"""Clear all relevant environment variables."""
for key in self.env_vars:
os.environ.pop(key, None)
def create_realistic_legacy_structure(self, base_dir: Path) -> Path:
"""Create a realistic legacy ~/.llama directory structure."""
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
# Create distributions with realistic content
distributions_dir = legacy_dir / "distributions"
distributions_dir.mkdir()
# Ollama distribution
ollama_dir = distributions_dir / "ollama"
ollama_dir.mkdir()
ollama_run_yaml = ollama_dir / "ollama-run.yaml"
ollama_run_yaml.write_text("""
version: 2
apis:
- inference
- safety
- memory
- vector_io
- agents
- files
providers:
inference:
- provider_type: remote::ollama
config:
url: http://localhost:11434
""")
ollama_build_yaml = ollama_dir / "build.yaml"
ollama_build_yaml.write_text("""
name: ollama
description: Ollama inference provider
docker_image: ollama:latest
""")
# Create providers.d structure
providers_dir = legacy_dir / "providers.d"
providers_dir.mkdir()
remote_dir = providers_dir / "remote"
remote_dir.mkdir()
inference_dir = remote_dir / "inference"
inference_dir.mkdir()
custom_provider = inference_dir / "custom-inference.yaml"
custom_provider.write_text("""
provider_type: remote::custom
config:
url: http://localhost:8080
api_key: test_key
""")
# Create checkpoints with model files
checkpoints_dir = legacy_dir / "checkpoints"
checkpoints_dir.mkdir()
model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct"
model_dir.mkdir(parents=True)
# Create fake model files
(model_dir / "consolidated.00.pth").write_bytes(b"fake model weights" * 1000)
(model_dir / "params.json").write_text('{"dim": 2048, "n_layers": 22}')
(model_dir / "tokenizer.model").write_bytes(b"fake tokenizer" * 100)
# Create runtime with databases
runtime_dir = legacy_dir / "runtime"
runtime_dir.mkdir()
(runtime_dir / "trace_store.db").write_text("SQLite format 3\x00" + "fake database content")
(runtime_dir / "agent_sessions.db").write_text("SQLite format 3\x00" + "fake agent sessions")
# Create config files
(legacy_dir / "config.json").write_text('{"version": "0.2.13", "last_updated": "2024-01-01"}')
return legacy_dir
def verify_xdg_migration_complete(self, base_dir: Path, legacy_dir: Path):
"""Verify that migration to XDG structure is complete and correct."""
config_dir = base_dir / ".config" / "llama-stack"
data_dir = base_dir / ".local" / "share" / "llama-stack"
state_dir = base_dir / ".local" / "state" / "llama-stack"
# Verify distributions moved to config
self.assertTrue((config_dir / "distributions").exists())
self.assertTrue((config_dir / "distributions" / "ollama").exists())
self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
# Verify YAML content is preserved
yaml_content = (config_dir / "distributions" / "ollama" / "ollama-run.yaml").read_text()
self.assertIn("version: 2", yaml_content)
self.assertIn("remote::ollama", yaml_content)
# Verify providers.d moved to config
self.assertTrue((config_dir / "providers.d").exists())
self.assertTrue((config_dir / "providers.d" / "remote" / "inference").exists())
self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom-inference.yaml").exists())
# Verify checkpoints moved to data
self.assertTrue((data_dir / "checkpoints").exists())
self.assertTrue((data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct").exists())
self.assertTrue(
(data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists()
)
# Verify model file content preserved
model_file = data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth"
self.assertGreater(model_file.stat().st_size, 1000) # Should be substantial size
# Verify runtime moved to state
self.assertTrue((state_dir / "runtime").exists())
self.assertTrue((state_dir / "runtime" / "trace_store.db").exists())
self.assertTrue((state_dir / "runtime" / "agent_sessions.db").exists())
# Verify database files preserved
db_file = state_dir / "runtime" / "trace_store.db"
db_content = db_file.read_text()
self.assertIn("SQLite format 3", db_content)
def test_fresh_installation_xdg_compliance(self):
"""Test fresh installation uses XDG-compliant paths."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Set custom XDG paths
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "custom_config")
os.environ["XDG_DATA_HOME"] = str(base_dir / "custom_data")
os.environ["XDG_STATE_HOME"] = str(base_dir / "custom_state")
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Mock that no legacy directory exists
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Fresh installation should use XDG paths
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
state_dir = get_llama_stack_state_dir()
self.assertEqual(config_dir, base_dir / "custom_config" / "llama-stack")
self.assertEqual(data_dir, base_dir / "custom_data" / "llama-stack")
self.assertEqual(state_dir, base_dir / "custom_state" / "llama-stack")
def test_complete_migration_workflow(self):
"""Test complete migration workflow from legacy to XDG."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create realistic legacy structure
legacy_dir = self.create_realistic_legacy_structure(base_dir)
# Verify legacy structure exists
self.assertTrue(legacy_dir.exists())
self.assertTrue((legacy_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
# Perform migration
from llama_stack.cli.migrate_xdg import migrate_to_xdg
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify migration completed successfully
self.verify_xdg_migration_complete(base_dir, legacy_dir)
# Verify legacy directory was removed
self.assertFalse(legacy_dir.exists())
def test_migration_preserves_file_integrity(self):
"""Test that migration preserves file integrity and permissions."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_realistic_legacy_structure(base_dir)
# Set specific permissions on files
config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
config_file.chmod(0o600)
original_stat = config_file.stat()
# Perform migration
from llama_stack.cli.migrate_xdg import migrate_to_xdg
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify file integrity
migrated_config = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml"
self.assertTrue(migrated_config.exists())
# Verify content is identical
migrated_content = migrated_config.read_text()
self.assertIn("version: 2", migrated_content)
self.assertIn("remote::ollama", migrated_content)
# Verify permissions preserved
migrated_stat = migrated_config.stat()
self.assertEqual(original_stat.st_mode, migrated_stat.st_mode)
def test_mixed_legacy_and_xdg_environment(self):
"""Test behavior in mixed legacy and XDG environment."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Set partial XDG environment
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "xdg_config")
# Leave XDG_DATA_HOME and XDG_STATE_HOME unset
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy directory
legacy_dir = self.create_realistic_legacy_structure(base_dir)
# Should use legacy directory since it exists
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
state_dir = get_llama_stack_state_dir()
self.assertEqual(config_dir, legacy_dir)
self.assertEqual(data_dir, legacy_dir)
self.assertEqual(state_dir, legacy_dir)
def test_template_rendering_with_xdg_paths(self):
"""Test that templates render correctly with XDG paths."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Set XDG environment
os.environ["XDG_STATE_HOME"] = str(base_dir / "state")
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
# Mock shell environment variable expansion
def mock_env_expand(template_string):
"""Mock shell environment variable expansion."""
result = template_string
result = result.replace("${env.XDG_STATE_HOME:-~/.local/state}", str(base_dir / "state"))
result = result.replace("${env.XDG_DATA_HOME:-~/.local/share}", str(base_dir / "data"))
return result
# Test template patterns
template_patterns = [
"${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama",
"${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files",
]
expected_results = [
str(base_dir / "state" / "llama-stack" / "distributions" / "ollama"),
str(base_dir / "data" / "llama-stack" / "distributions" / "ollama" / "files"),
]
for pattern, expected in zip(template_patterns, expected_results, strict=False):
with self.subTest(pattern=pattern):
expanded = mock_env_expand(pattern)
self.assertEqual(expanded, expected)
def test_cli_integration_with_xdg_paths(self):
"""Test CLI integration works correctly with XDG paths."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_realistic_legacy_structure(base_dir)
# Test CLI migrate command
import argparse
from llama_stack.cli.migrate_xdg import MigrateXDG
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
migrate_cmd = MigrateXDG.create(subparsers)
# Test dry-run
args = parser.parse_args(["migrate-xdg", "--dry-run"])
with patch("builtins.print") as mock_print:
result = migrate_cmd._run_migrate_xdg_cmd(args)
self.assertEqual(result, 0)
# Should print dry-run information
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Dry run mode" in call for call in print_calls))
# Legacy directory should still exist after dry-run
self.assertTrue(legacy_dir.exists())
def test_config_dirs_integration_after_migration(self):
"""Test that config_dirs works correctly after migration."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create and migrate legacy structure
self.create_realistic_legacy_structure(base_dir)
from llama_stack.cli.migrate_xdg import migrate_to_xdg
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Clear module cache to ensure fresh import
import sys
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
del sys.modules["llama_stack.distribution.utils.config_dirs"]
# Import config_dirs after migration
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# Should use XDG paths
self.assertEqual(LLAMA_STACK_CONFIG_DIR, base_dir / ".config" / "llama-stack")
self.assertEqual(DEFAULT_CHECKPOINT_DIR, base_dir / ".local" / "share" / "llama-stack" / "checkpoints")
self.assertEqual(RUNTIME_BASE_DIR, base_dir / ".local" / "state" / "llama-stack" / "runtime")
self.assertEqual(DISTRIBS_BASE_DIR, base_dir / ".config" / "llama-stack" / "distributions")
def test_real_file_operations_with_xdg_paths(self):
"""Test real file operations work correctly with XDG paths."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Set XDG environment
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config")
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
os.environ["XDG_STATE_HOME"] = str(base_dir / "state")
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Get XDG paths
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
state_dir = get_llama_stack_state_dir()
# Create directories
config_dir.mkdir(parents=True)
data_dir.mkdir(parents=True)
state_dir.mkdir(parents=True)
# Test writing configuration files
config_file = config_dir / "test_config.yaml"
config_data = {"version": "2", "test": True}
with open(config_file, "w") as f:
yaml.dump(config_data, f)
# Test reading configuration files
with open(config_file) as f:
loaded_config = yaml.safe_load(f)
self.assertEqual(loaded_config, config_data)
# Test creating nested directory structure
model_dir = data_dir / "checkpoints" / "meta-llama" / "test-model"
model_dir.mkdir(parents=True)
# Test writing large files
model_file = model_dir / "model.bin"
test_data = b"test model data" * 1000
model_file.write_bytes(test_data)
# Verify file integrity
read_data = model_file.read_bytes()
self.assertEqual(read_data, test_data)
# Test state files
state_file = state_dir / "runtime" / "session.db"
state_file.parent.mkdir(parents=True)
state_file.write_text("SQLite format 3\x00test database")
# Verify state file
state_content = state_file.read_text()
self.assertIn("SQLite format 3", state_content)
def test_backwards_compatibility_scenario(self):
"""Test complete backwards compatibility scenario."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Scenario: User has existing legacy installation
legacy_dir = self.create_realistic_legacy_structure(base_dir)
# User sets legacy environment variable
os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Should continue using legacy paths
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
state_dir = get_llama_stack_state_dir()
self.assertEqual(config_dir, legacy_dir)
self.assertEqual(data_dir, legacy_dir)
self.assertEqual(state_dir, legacy_dir)
# Should be able to access existing files
yaml_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
self.assertTrue(yaml_file.exists())
# Should be able to parse existing configuration
with open(yaml_file) as f:
config = yaml.safe_load(f)
self.assertEqual(config["version"], 2)
def test_error_recovery_scenarios(self):
"""Test error recovery in various scenarios."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Scenario 1: Partial migration failure
self.create_realistic_legacy_structure(base_dir)
# Create conflicting file in target location
config_dir = base_dir / ".config" / "llama-stack"
config_dir.mkdir(parents=True)
conflicting_file = config_dir / "distributions"
conflicting_file.touch() # Create file instead of directory
from llama_stack.cli.migrate_xdg import migrate_to_xdg
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=False)
# Should handle conflicts gracefully
print_calls = [call[0][0] for call in mock_print.call_args_list]
conflict_mentioned = any(
"Warning" in call or "conflict" in call.lower() for call in print_calls
)
# Migration should complete with warnings
self.assertTrue(result or conflict_mentioned)
def test_cross_platform_compatibility(self):
"""Test cross-platform compatibility of XDG implementation."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Test with different path separators and formats
if os.name == "nt": # Windows
# Test Windows-style paths
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config").replace("/", "\\")
os.environ["XDG_DATA_HOME"] = str(base_dir / "data").replace("/", "\\")
else: # Unix-like
# Test Unix-style paths
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config")
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Should work regardless of platform
config_dir = get_llama_stack_config_dir()
data_dir = get_llama_stack_data_dir()
# Paths should be valid for the current platform
self.assertTrue(config_dir.is_absolute())
self.assertTrue(data_dir.is_absolute())
# Should be able to create directories
config_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)
self.assertTrue(config_dir.exists())
self.assertTrue(data_dir.exists())
if __name__ == "__main__":
unittest.main()


@ -0,0 +1,516 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from llama_stack.cli.migrate_xdg import migrate_to_xdg
class TestXDGMigrationIntegration(unittest.TestCase):
"""Integration tests for XDG migration functionality."""
def setUp(self):
"""Set up test environment."""
# Store original environment variables
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
# Clear environment variables
for key in self.original_env:
os.environ.pop(key, None)
def tearDown(self):
"""Clean up test environment."""
# Restore original environment variables
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def create_legacy_structure(self, base_dir: Path) -> Path:
"""Create a realistic legacy ~/.llama directory structure."""
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
# Create distributions
distributions_dir = legacy_dir / "distributions"
distributions_dir.mkdir()
# Create sample distribution
ollama_dir = distributions_dir / "ollama"
ollama_dir.mkdir()
(ollama_dir / "ollama-run.yaml").write_text("version: 2\napis: []\n")
(ollama_dir / "build.yaml").write_text("name: ollama\n")
# Create providers.d
providers_dir = legacy_dir / "providers.d"
providers_dir.mkdir()
(providers_dir / "remote").mkdir()
(providers_dir / "remote" / "inference").mkdir()
(providers_dir / "remote" / "inference" / "custom.yaml").write_text("provider_type: remote::custom\n")
# Create checkpoints
checkpoints_dir = legacy_dir / "checkpoints"
checkpoints_dir.mkdir()
model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct"
model_dir.mkdir(parents=True)
(model_dir / "consolidated.00.pth").write_text("fake model weights")
(model_dir / "params.json").write_text('{"dim": 2048}')
# Create runtime
runtime_dir = legacy_dir / "runtime"
runtime_dir.mkdir()
(runtime_dir / "trace_store.db").write_text("fake sqlite database")
# Create some fake files in various subdirectories
(legacy_dir / "config.json").write_text('{"version": "0.2.13"}')
return legacy_dir
def verify_xdg_structure(self, base_dir: Path, legacy_dir: Path):
"""Verify that the XDG structure was created correctly."""
config_dir = base_dir / ".config" / "llama-stack"
data_dir = base_dir / ".local" / "share" / "llama-stack"
state_dir = base_dir / ".local" / "state" / "llama-stack"
# Verify distributions moved to config
self.assertTrue((config_dir / "distributions").exists())
self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
# Verify providers.d moved to config
self.assertTrue((config_dir / "providers.d").exists())
self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom.yaml").exists())
# Verify checkpoints moved to data
self.assertTrue((data_dir / "checkpoints").exists())
self.assertTrue(
(data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists()
)
# Verify runtime moved to state
self.assertTrue((state_dir / "runtime").exists())
self.assertTrue((state_dir / "runtime" / "trace_store.db").exists())
def test_full_migration_workflow(self):
"""Test complete migration workflow from legacy to XDG."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
# Set up fake home directory
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# Verify legacy structure exists
self.assertTrue(legacy_dir.exists())
self.assertTrue((legacy_dir / "distributions").exists())
self.assertTrue((legacy_dir / "checkpoints").exists())
# Perform migration with user confirmation
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify XDG structure was created
self.verify_xdg_structure(base_dir, legacy_dir)
# Verify legacy directory was removed
self.assertFalse(legacy_dir.exists())
def test_migration_dry_run(self):
"""Test dry run migration (no actual file movement)."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# Perform dry run migration
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=True)
self.assertTrue(result)
# Check that dry run message was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Dry run mode" in call for call in print_calls))
# Verify nothing was actually moved
self.assertTrue(legacy_dir.exists())
self.assertTrue((legacy_dir / "distributions").exists())
self.assertFalse((base_dir / ".config" / "llama-stack").exists())
def test_migration_user_cancellation(self):
"""Test migration when user cancels."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# User cancels migration
with patch("builtins.input") as mock_input:
mock_input.return_value = "n"
result = migrate_to_xdg(dry_run=False)
self.assertFalse(result)
# Verify nothing was moved
self.assertTrue(legacy_dir.exists())
self.assertTrue((legacy_dir / "distributions").exists())
self.assertFalse((base_dir / ".config" / "llama-stack").exists())
def test_migration_with_existing_xdg_directories(self):
"""Test migration when XDG directories already exist."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# Create existing XDG structure with conflicting files
config_dir = base_dir / ".config" / "llama-stack"
config_dir.mkdir(parents=True)
(config_dir / "distributions").mkdir()
(config_dir / "distributions" / "existing.yaml").write_text("existing config")
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Check that warning was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Warning: Target already exists" in call for call in print_calls))
# Verify existing file wasn't overwritten
self.assertTrue((config_dir / "distributions" / "existing.yaml").exists())
# Legacy distributions should still exist due to conflict
self.assertTrue((legacy_dir / "distributions").exists())
def test_migration_partial_success(self):
"""Test migration when some items succeed and others fail."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
self.create_legacy_structure(base_dir)
# Create readonly target directory to simulate permission error
config_dir = base_dir / ".config" / "llama-stack"
config_dir.mkdir(parents=True)
distributions_target = config_dir / "distributions"
distributions_target.mkdir()
distributions_target.chmod(0o444) # Read-only
try:
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
result = migrate_to_xdg(dry_run=False)
# Should return True even with partial success
self.assertTrue(result)
# Some items should have been migrated successfully
self.assertTrue((base_dir / ".local" / "share" / "llama-stack" / "checkpoints").exists())
finally:
# Restore permissions for cleanup
distributions_target.chmod(0o755)
def test_migration_empty_legacy_directory(self):
"""Test migration when legacy directory exists but is empty."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create empty legacy directory
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
def test_migration_preserves_file_permissions(self):
"""Test that migration preserves file permissions."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure with specific permissions
legacy_dir = self.create_legacy_structure(base_dir)
# Set specific permissions on a file
config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
config_file.chmod(0o600)
original_stat = config_file.stat()
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify permissions were preserved
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml"
self.assertTrue(migrated_file.exists())
migrated_stat = migrated_file.stat()
self.assertEqual(original_stat.st_mode, migrated_stat.st_mode)
def test_migration_preserves_directory_structure(self):
"""Test that migration preserves complex directory structures."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create complex legacy structure
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
# Create nested structure
complex_path = legacy_dir / "checkpoints" / "org" / "model" / "variant" / "files"
complex_path.mkdir(parents=True)
(complex_path / "model.bin").write_text("model data")
(complex_path / "config.json").write_text("config data")
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify structure was preserved
migrated_path = (
base_dir
/ ".local"
/ "share"
/ "llama-stack"
/ "checkpoints"
/ "org"
/ "model"
/ "variant"
/ "files"
)
self.assertTrue(migrated_path.exists())
self.assertTrue((migrated_path / "model.bin").exists())
self.assertTrue((migrated_path / "config.json").exists())
def test_migration_with_symlinks(self):
"""Test migration with symbolic links in legacy directory."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# Create a symlink
actual_file = legacy_dir / "actual_config.yaml"
actual_file.write_text("actual config content")
symlink_file = legacy_dir / "distributions" / "symlinked_config.yaml"
symlink_file.symlink_to(actual_file)
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify symlink was preserved
migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked_config.yaml"
self.assertTrue(migrated_symlink.exists())
self.assertTrue(migrated_symlink.is_symlink())
def test_migration_large_files(self):
"""Test migration with large files (simulated)."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure
legacy_dir = self.create_legacy_structure(base_dir)
# Create a larger file (1MB)
large_file = legacy_dir / "checkpoints" / "large_model.bin"
large_file.write_bytes(b"0" * (1024 * 1024))
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify large file was moved correctly
migrated_file = base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "large_model.bin"
self.assertTrue(migrated_file.exists())
self.assertEqual(migrated_file.stat().st_size, 1024 * 1024)
def test_migration_with_unicode_filenames(self):
"""Test migration with unicode filenames."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy structure with unicode filenames
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
unicode_dir = legacy_dir / "distributions" / "配置"
unicode_dir.mkdir(parents=True)
unicode_file = unicode_dir / "模型.yaml"
unicode_file.write_text("unicode content")
# Perform migration
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"]
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Verify unicode files were migrated
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "配置" / "模型.yaml"
self.assertTrue(migrated_file.exists())
self.assertEqual(migrated_file.read_text(), "unicode content")
class TestXDGMigrationCLI(unittest.TestCase):
"""Test the CLI interface for XDG migration."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
os.environ.pop(key, None)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def test_cli_migrate_command_exists(self):
"""Test that the migrate-xdg CLI command is properly registered."""
from llama_stack.cli.llama import LlamaCLIParser
parser = LlamaCLIParser()
# Parse help to see if migrate-xdg is listed
with patch("sys.argv", ["llama", "--help"]):
with patch("sys.exit"):
with patch("builtins.print") as mock_print:
try:
parser.parse_args()
except SystemExit:
pass
# Check if migrate-xdg appears in help output
help_output = "\n".join([call[0][0] for call in mock_print.call_args_list])
self.assertIn("migrate-xdg", help_output)
def test_cli_migrate_dry_run(self):
"""Test CLI migrate command with dry-run flag."""
import argparse
from llama_stack.cli.migrate_xdg import MigrateXDG
# Create parser and add migrate command
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
MigrateXDG.create(subparsers)
# Test dry-run flag
args = parser.parse_args(["migrate-xdg", "--dry-run"])
self.assertTrue(args.dry_run)
# Test without dry-run flag
args = parser.parse_args(["migrate-xdg"])
self.assertFalse(args.dry_run)
def test_cli_migrate_execution(self):
"""Test CLI migrate command execution."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create legacy directory
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch()
import argparse
from llama_stack.cli.migrate_xdg import MigrateXDG
# Create parser and command
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
migrate_cmd = MigrateXDG.create(subparsers)
# Parse arguments
args = parser.parse_args(["migrate-xdg", "--dry-run"])
# Execute command
with patch("builtins.print") as mock_print:
migrate_cmd._run_migrate_xdg_cmd(args)
# Verify output was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
if __name__ == "__main__":
unittest.main()

tests/run_xdg_tests.py Executable file

@ -0,0 +1,164 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test runner for XDG Base Directory Specification compliance tests.
This script runs all XDG-related tests and provides a comprehensive
report of the test results.

Usage:
    python tests/run_xdg_tests.py          # run the full suite
    python tests/run_xdg_tests.py --quick  # run a quick subset of critical tests
"""
import sys
import unittest
from pathlib import Path
def run_test_suite():
"""Run the complete XDG test suite."""
# Set up test environment
test_dir = Path(__file__).parent
project_root = test_dir.parent
# Add project root to Python path
sys.path.insert(0, str(project_root))
# Test modules to run
test_modules = [
"tests.unit.test_xdg_compliance",
"tests.unit.test_config_dirs",
"tests.unit.cli.test_migrate_xdg",
"tests.unit.test_template_xdg_paths",
"tests.unit.test_xdg_edge_cases",
"tests.integration.test_xdg_migration",
"tests.integration.test_xdg_e2e",
]
# Discover and run tests
loader = unittest.TestLoader()
suite = unittest.TestSuite()
print("🔍 Discovering XDG compliance tests...")
for module_name in test_modules:
try:
# Try to load the test module
module_suite = loader.loadTestsFromName(module_name)
suite.addTest(module_suite)
print(f" ✅ Loaded {module_name}")
except Exception as e:
print(f" ⚠️ Failed to load {module_name}: {e}")
# Run the tests
print("\n🧪 Running XDG compliance tests...")
print("=" * 60)
runner = unittest.TextTestRunner(verbosity=2, stream=sys.stdout, descriptions=True, buffer=True)
result = runner.run(suite)
# Print summary
print("\n" + "=" * 60)
print("📊 Test Summary:")
print(f" Tests run: {result.testsRun}")
print(f" Failures: {len(result.failures)}")
print(f" Errors: {len(result.errors)}")
print(f" Skipped: {len(result.skipped)}")
if result.failures:
print("\n❌ Failures:")
for test, traceback in result.failures:
print(f" - {test}: {traceback.split('AssertionError:')[-1].strip()}")
if result.errors:
print("\n💥 Errors:")
for test, traceback in result.errors:
print(f" - {test}: {traceback.strip().splitlines()[-1]}")
if result.skipped:
print("\n⏭️ Skipped:")
for test, reason in result.skipped:
print(f" - {test}: {reason}")
# Overall result
if result.wasSuccessful():
print("\n🎉 All XDG compliance tests passed!")
return 0
else:
print("\n⚠️ Some XDG compliance tests failed.")
return 1
def run_quick_tests():
"""Run a quick subset of critical XDG tests."""
print("🚀 Running quick XDG compliance tests...")
# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# Quick test: Basic XDG functionality
try:
from llama_stack.distribution.utils.xdg_utils import (
get_llama_stack_config_dir,
get_xdg_compliant_path,
get_xdg_config_home,
)
print(" ✅ XDG utilities import successfully")
# Test basic functionality
config_home = get_xdg_config_home()
llama_config = get_llama_stack_config_dir()
compliant_path = get_xdg_compliant_path("config", "test")
print(f" ✅ XDG config home: {config_home}")
print(f" ✅ Llama Stack config: {llama_config}")
print(f" ✅ Compliant path: {compliant_path}")
except Exception as e:
print(f" ❌ XDG utilities failed: {e}")
return 1
# Quick test: Config dirs integration
try:
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
LLAMA_STACK_CONFIG_DIR,
)
print(f" ✅ Config dirs integration: {LLAMA_STACK_CONFIG_DIR}")
print(f" ✅ Checkpoint directory: {DEFAULT_CHECKPOINT_DIR}")
except Exception as e:
print(f" ❌ Config dirs integration failed: {e}")
return 1
# Quick test: CLI migrate command
try:
from llama_stack.cli.migrate_xdg import MigrateXDG  # noqa: F401
print(" ✅ CLI migrate command available")
except Exception as e:
print(f" ❌ CLI migrate command failed: {e}")
return 1
print("\n🎉 Quick XDG compliance tests passed!")
return 0
def main():
"""Main test runner entry point."""
if len(sys.argv) > 1 and sys.argv[1] == "--quick":
return run_quick_tests()
else:
return run_test_suite()
if __name__ == "__main__":
sys.exit(main())
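The quick checks above call `get_xdg_config_home()` and the suites later in this commit expect it to fall back to `~/.config` when `XDG_CONFIG_HOME` is unset or empty. A minimal sketch of that fallback (hypothetical; the real `xdg_utils` implementation may differ in detail):

```python
import os
from pathlib import Path


def xdg_config_home() -> Path:
    """Return $XDG_CONFIG_HOME, treating an unset *or empty* value as absent,
    as the XDG Base Directory spec requires."""
    value = os.environ.get("XDG_CONFIG_HOME", "")
    return Path(value).expanduser() if value.strip() else Path.home() / ".config"
```

The same unset-or-empty rule applies to `XDG_DATA_HOME` and `XDG_STATE_HOME` with their respective `~/.local/share` and `~/.local/state` defaults.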


@ -0,0 +1,489 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
import tempfile
import unittest
from io import StringIO
from pathlib import Path
from unittest.mock import patch
from llama_stack.cli.migrate_xdg import MigrateXDG, migrate_to_xdg
class TestMigrateXDGCLI(unittest.TestCase):
"""Test the MigrateXDG CLI command."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
os.environ.pop(key, None)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def create_parser_with_migrate_cmd(self):
"""Create parser with migrate-xdg command."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
migrate_cmd = MigrateXDG.create(subparsers)
return parser, migrate_cmd
def test_migrate_xdg_command_creation(self):
"""Test that MigrateXDG command can be created."""
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
self.assertIsInstance(migrate_cmd, MigrateXDG)
self.assertEqual(migrate_cmd.parser.prog, "llama migrate-xdg")
self.assertEqual(migrate_cmd.parser.description, "Migrate from legacy ~/.llama to XDG-compliant directories")
def test_migrate_xdg_argument_parsing(self):
"""Test argument parsing for migrate-xdg command."""
parser, _ = self.create_parser_with_migrate_cmd()
# Test with dry-run flag
args = parser.parse_args(["migrate-xdg", "--dry-run"])
self.assertEqual(args.command, "migrate-xdg")
self.assertTrue(args.dry_run)
# Test without dry-run flag
args = parser.parse_args(["migrate-xdg"])
self.assertEqual(args.command, "migrate-xdg")
self.assertFalse(args.dry_run)
def test_migrate_xdg_help_text(self):
"""Test help text for migrate-xdg command."""
parser, _ = self.create_parser_with_migrate_cmd()
# Capture help output
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
with patch("sys.exit"):
try:
parser.parse_args(["migrate-xdg", "--help"])
except SystemExit:
pass
help_text = mock_stdout.getvalue()
self.assertIn("migrate-xdg", help_text)
self.assertIn("XDG-compliant directories", help_text)
self.assertIn("--dry-run", help_text)
def test_migrate_xdg_command_execution_no_legacy(self):
"""Test command execution when no legacy directory exists."""
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
with tempfile.TemporaryDirectory() as temp_dir:
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = Path(temp_dir)
args = parser.parse_args(["migrate-xdg"])
with patch("builtins.print") as mock_print:
result = migrate_cmd._run_migrate_xdg_cmd(args)
# Should succeed when no migration needed
self.assertEqual(result, 0)
# Should print appropriate message
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("No legacy directory found" in call for call in print_calls))
def test_migrate_xdg_command_execution_with_legacy(self):
"""Test command execution when legacy directory exists."""
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
args = parser.parse_args(["migrate-xdg"])
with patch("builtins.print") as mock_print:
result = migrate_cmd._run_migrate_xdg_cmd(args)
# Should succeed
self.assertEqual(result, 0)
# Should print migration information
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
def test_migrate_xdg_command_execution_dry_run(self):
"""Test command execution with dry-run flag."""
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
args = parser.parse_args(["migrate-xdg", "--dry-run"])
with patch("builtins.print") as mock_print:
result = migrate_cmd._run_migrate_xdg_cmd(args)
# Should succeed
self.assertEqual(result, 0)
# Should print dry-run information
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Dry run mode" in call for call in print_calls))
def test_migrate_xdg_command_execution_error_handling(self):
"""Test command execution with error handling."""
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
args = parser.parse_args(["migrate-xdg"])
# Mock migrate_to_xdg to raise an exception
with patch("llama_stack.cli.migrate_xdg.migrate_to_xdg") as mock_migrate:
mock_migrate.side_effect = Exception("Test error")
with patch("builtins.print") as mock_print:
result = migrate_cmd._run_migrate_xdg_cmd(args)
# Should return error code
self.assertEqual(result, 1)
# Should print error message
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Error during migration" in call for call in print_calls))
def test_migrate_xdg_command_integration(self):
"""Test full integration of migrate-xdg command."""
from llama_stack.cli.llama import LlamaCLIParser
# Create main parser
main_parser = LlamaCLIParser()
# Test that migrate-xdg is in the subcommands
with patch("sys.argv", ["llama", "migrate-xdg", "--help"]):
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
with patch("sys.exit"):
try:
main_parser.parse_args()
except SystemExit:
pass
help_text = mock_stdout.getvalue()
self.assertIn("migrate-xdg", help_text)
class TestMigrateXDGFunction(unittest.TestCase):
"""Test the migrate_to_xdg function directly."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
os.environ.pop(key, None)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def create_legacy_structure(self, base_dir: Path) -> Path:
"""Create a test legacy directory structure."""
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir()
# Create distributions
(legacy_dir / "distributions").mkdir()
(legacy_dir / "distributions" / "ollama").mkdir()
(legacy_dir / "distributions" / "ollama" / "run.yaml").write_text("version: 2\n")
# Create checkpoints
(legacy_dir / "checkpoints").mkdir()
(legacy_dir / "checkpoints" / "model.bin").write_text("fake model")
# Create providers.d
(legacy_dir / "providers.d").mkdir()
(legacy_dir / "providers.d" / "provider.yaml").write_text("provider: test\n")
# Create runtime
(legacy_dir / "runtime").mkdir()
(legacy_dir / "runtime" / "trace.db").write_text("fake database")
return legacy_dir
def test_migrate_to_xdg_no_legacy_directory(self):
"""Test migrate_to_xdg when no legacy directory exists."""
with tempfile.TemporaryDirectory() as temp_dir:
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = Path(temp_dir)
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
def test_migrate_to_xdg_dry_run(self):
"""Test migrate_to_xdg with dry_run=True."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=True)
self.assertTrue(result)
# Should print dry run information
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Dry run mode" in call for call in print_calls))
# Legacy directory should still exist
self.assertTrue(legacy_dir.exists())
def test_migrate_to_xdg_user_confirms(self):
"""Test migrate_to_xdg when user confirms migration."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Legacy directory should be removed
self.assertFalse(legacy_dir.exists())
# XDG directories should be created
self.assertTrue((base_dir / ".config" / "llama-stack").exists())
self.assertTrue((base_dir / ".local" / "share" / "llama-stack").exists())
def test_migrate_to_xdg_user_cancels(self):
"""Test migrate_to_xdg when user cancels migration."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.input") as mock_input:
mock_input.return_value = "n" # Cancel migration
result = migrate_to_xdg(dry_run=False)
self.assertFalse(result)
# Legacy directory should still exist
self.assertTrue(legacy_dir.exists())
def test_migrate_to_xdg_partial_migration(self):
"""Test migrate_to_xdg with partial migration (some files fail)."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
            self.assertTrue(legacy_dir.exists())
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create target directory with conflicting file
config_dir = base_dir / ".config" / "llama-stack"
config_dir.mkdir(parents=True)
(config_dir / "distributions").mkdir()
(config_dir / "distributions" / "existing.yaml").write_text("existing")
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Should print warning about conflicts
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Warning: Target already exists" in call for call in print_calls))
def test_migrate_to_xdg_permission_error(self):
"""Test migrate_to_xdg with permission errors."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
            self.assertTrue(legacy_dir.exists())
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Create readonly target directory
config_dir = base_dir / ".config" / "llama-stack"
config_dir.mkdir(parents=True)
config_dir.chmod(0o444) # Read-only
try:
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
with patch("builtins.print") as mock_print:
result = migrate_to_xdg(dry_run=False)
self.assertFalse(result)
# Should handle permission errors gracefully
print_calls = [call[0][0] for call in mock_print.call_args_list]
# Should contain some error or warning message
self.assertTrue(len(print_calls) > 0)
finally:
# Restore permissions for cleanup
config_dir.chmod(0o755)
def test_migrate_to_xdg_empty_legacy_directory(self):
"""Test migrate_to_xdg with empty legacy directory."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = base_dir / ".llama"
legacy_dir.mkdir() # Empty directory
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
def test_migrate_to_xdg_preserves_file_content(self):
"""Test that migrate_to_xdg preserves file content correctly."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
# Add specific content to test
test_content = "test configuration content"
(legacy_dir / "distributions" / "ollama" / "run.yaml").write_text(test_content)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Check content was preserved
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "run.yaml"
self.assertTrue(migrated_file.exists())
self.assertEqual(migrated_file.read_text(), test_content)
def test_migrate_to_xdg_with_symlinks(self):
"""Test migrate_to_xdg with symbolic links."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
# Create symlink
actual_file = legacy_dir / "actual_config.yaml"
actual_file.write_text("actual config")
symlink_file = legacy_dir / "distributions" / "symlinked.yaml"
symlink_file.symlink_to(actual_file)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Check symlink was preserved
migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked.yaml"
self.assertTrue(migrated_symlink.exists())
self.assertTrue(migrated_symlink.is_symlink())
def test_migrate_to_xdg_nested_directory_structure(self):
"""Test migrate_to_xdg with nested directory structures."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
# Create nested structure
nested_dir = legacy_dir / "checkpoints" / "org" / "model" / "variant"
nested_dir.mkdir(parents=True)
(nested_dir / "model.bin").write_text("nested model")
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
with patch("builtins.input") as mock_input:
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result)
# Check nested structure was preserved
migrated_nested = (
base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "org" / "model" / "variant"
)
self.assertTrue(migrated_nested.exists())
self.assertTrue((migrated_nested / "model.bin").exists())
def test_migrate_to_xdg_user_input_variations(self):
"""Test migrate_to_xdg with various user input variations."""
with tempfile.TemporaryDirectory() as temp_dir:
base_dir = Path(temp_dir)
legacy_dir = self.create_legacy_structure(base_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = base_dir
# Test various forms of "yes"
for yes_input in ["y", "Y", "yes", "Yes", "YES"]:
# Recreate legacy directory for each test
if legacy_dir.exists():
import shutil
shutil.rmtree(legacy_dir)
self.create_legacy_structure(base_dir)
with patch("builtins.input") as mock_input:
mock_input.side_effect = [yes_input, "n"] # Confirm migration, don't cleanup
result = migrate_to_xdg(dry_run=False)
self.assertTrue(result, f"Failed with input: {yes_input}")
if __name__ == "__main__":
unittest.main()
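The migration tests above pin down a legacy-to-XDG mapping: `distributions/` and `providers.d/` move under the config home, `checkpoints/` under the data home, and `runtime/` under the state home. That mapping can be sketched as follows (hypothetical helper name; inferred from the test assertions, not from `migrate_to_xdg` itself):

```python
from pathlib import Path


def plan_migration(home: Path) -> dict[Path, Path]:
    """Map legacy ~/.llama subdirectories to their XDG-compliant destinations."""
    legacy = home / ".llama"
    config = home / ".config" / "llama-stack"
    data = home / ".local" / "share" / "llama-stack"
    state = home / ".local" / "state" / "llama-stack"
    return {
        legacy / "distributions": config / "distributions",
        legacy / "providers.d": config / "providers.d",
        legacy / "checkpoints": data / "checkpoints",
        legacy / "runtime": state / "runtime",
    }
```

A real migration would walk each source tree, copy files (preserving symlinks and nested directories, as the tests verify), warn on conflicts, and only remove `~/.llama` after user confirmation.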


@ -0,0 +1,418 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
# Import after we set up environment to avoid module-level imports affecting tests
class TestConfigDirs(unittest.TestCase):
"""Test the config_dirs module with XDG compliance and backwards compatibility."""
def setUp(self):
"""Set up test environment."""
# Store original environment variables
self.original_env = {}
self.env_vars = [
"XDG_CONFIG_HOME",
"XDG_DATA_HOME",
"XDG_STATE_HOME",
"XDG_CACHE_HOME",
"LLAMA_STACK_CONFIG_DIR",
"SQLITE_STORE_DIR",
"FILES_STORAGE_DIR",
]
for key in self.env_vars:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
# Restore original environment variables
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
# Clear module cache to ensure fresh imports
import sys
modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"]
for module in modules_to_clear:
if module in sys.modules:
del sys.modules[module]
def clear_env_vars(self):
"""Clear all relevant environment variables."""
for key in self.env_vars:
os.environ.pop(key, None)
def test_config_dirs_xdg_defaults(self):
"""Test config_dirs with XDG default paths."""
self.clear_env_vars()
# Mock that no legacy directory exists
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = Path("/home/testuser")
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Import after setting up mocks
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# Verify XDG-compliant paths
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack"))
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints"))
self.assertEqual(RUNTIME_BASE_DIR, Path("/home/testuser/.local/state/llama-stack/runtime"))
self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/home/testuser/.config/llama-stack/providers.d"))
self.assertEqual(DISTRIBS_BASE_DIR, Path("/home/testuser/.config/llama-stack/distributions"))
def test_config_dirs_custom_xdg_paths(self):
"""Test config_dirs with custom XDG paths."""
self.clear_env_vars()
# Set custom XDG paths
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
os.environ["XDG_DATA_HOME"] = "/custom/data"
os.environ["XDG_STATE_HOME"] = "/custom/state"
# Mock that no legacy directory exists
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# Verify custom XDG paths are used
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/custom/config/llama-stack"))
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/custom/data/llama-stack/checkpoints"))
self.assertEqual(RUNTIME_BASE_DIR, Path("/custom/state/llama-stack/runtime"))
self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/custom/config/llama-stack/providers.d"))
self.assertEqual(DISTRIBS_BASE_DIR, Path("/custom/config/llama-stack/distributions"))
def test_config_dirs_legacy_environment_variable(self):
"""Test config_dirs with legacy LLAMA_STACK_CONFIG_DIR."""
self.clear_env_vars()
# Set legacy environment variable
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/llama"
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# All paths should use the legacy base
legacy_base = Path("/legacy/llama")
self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_base)
self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_base / "checkpoints")
self.assertEqual(RUNTIME_BASE_DIR, legacy_base / "runtime")
self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_base / "providers.d")
self.assertEqual(DISTRIBS_BASE_DIR, legacy_base / "distributions")
def test_config_dirs_legacy_directory_exists(self):
"""Test config_dirs when legacy ~/.llama directory exists."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
home_dir = Path(temp_dir)
legacy_dir = home_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch() # Add content
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home_dir
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# Should use legacy directory
self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_dir)
self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_dir / "checkpoints")
self.assertEqual(RUNTIME_BASE_DIR, legacy_dir / "runtime")
self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_dir / "providers.d")
self.assertEqual(DISTRIBS_BASE_DIR, legacy_dir / "distributions")
def test_config_dirs_precedence_order(self):
"""Test precedence order: LLAMA_STACK_CONFIG_DIR > legacy directory > XDG."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
home_dir = Path(temp_dir)
legacy_dir = home_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch()
# Set both legacy env var and XDG vars
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/priority/path"
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home_dir
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
# Legacy env var should take precedence
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/priority/path"))
def test_config_dirs_all_path_types(self):
"""Test that all path objects are of correct type and absolute."""
self.clear_env_vars()
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# All should be Path objects
paths = [
LLAMA_STACK_CONFIG_DIR,
DEFAULT_CHECKPOINT_DIR,
RUNTIME_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
DISTRIBS_BASE_DIR,
]
for path in paths:
self.assertIsInstance(path, Path, f"Path {path} should be Path object")
self.assertTrue(path.is_absolute(), f"Path {path} should be absolute")
def test_config_dirs_directory_relationships(self):
"""Test relationships between different directory paths."""
self.clear_env_vars()
from llama_stack.distribution.utils.config_dirs import (
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
)
# Test parent-child relationships
self.assertEqual(EXTERNAL_PROVIDERS_DIR.parent, LLAMA_STACK_CONFIG_DIR)
self.assertEqual(DISTRIBS_BASE_DIR.parent, LLAMA_STACK_CONFIG_DIR)
# Test expected subdirectory names
self.assertEqual(EXTERNAL_PROVIDERS_DIR.name, "providers.d")
self.assertEqual(DISTRIBS_BASE_DIR.name, "distributions")
def test_config_dirs_environment_isolation(self):
"""Test that config_dirs is properly isolated between tests."""
self.clear_env_vars()
# First import with one set of environment variables
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/first/path"
# Clear module cache
import sys
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
del sys.modules["llama_stack.distribution.utils.config_dirs"]
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_CONFIG
# Change environment and re-import
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/second/path"
# Clear module cache again
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
del sys.modules["llama_stack.distribution.utils.config_dirs"]
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_CONFIG
# Should get different paths
self.assertEqual(FIRST_CONFIG, Path("/first/path"))
self.assertEqual(SECOND_CONFIG, Path("/second/path"))
def test_config_dirs_with_tilde_expansion(self):
"""Test config_dirs with tilde in paths."""
self.clear_env_vars()
os.environ["LLAMA_STACK_CONFIG_DIR"] = "~/custom_llama"
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
# Should expand tilde
expected = Path.home() / "custom_llama"
self.assertEqual(LLAMA_STACK_CONFIG_DIR, expected)
def test_config_dirs_empty_environment_variables(self):
"""Test config_dirs with empty environment variables."""
self.clear_env_vars()
# Set empty values
os.environ["XDG_CONFIG_HOME"] = ""
os.environ["XDG_DATA_HOME"] = ""
# Mock no legacy directory
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = Path("/home/testuser")
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
LLAMA_STACK_CONFIG_DIR,
)
# Should fall back to defaults
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack"))
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints"))
def test_config_dirs_relative_paths(self):
"""Test config_dirs with relative paths in environment variables."""
self.clear_env_vars()
        with tempfile.TemporaryDirectory() as temp_dir:
            original_cwd = os.getcwd()
            os.chdir(temp_dir)
            try:
                # Use relative path
                os.environ["LLAMA_STACK_CONFIG_DIR"] = "relative/config"
                from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR

                # Should be resolved to absolute path
                self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute())
                self.assertTrue(str(LLAMA_STACK_CONFIG_DIR).endswith("relative/config"))
            finally:
                # Restore cwd so later tests don't run inside a deleted directory
                os.chdir(original_cwd)
def test_config_dirs_with_spaces_in_paths(self):
"""Test config_dirs with spaces in directory paths."""
self.clear_env_vars()
path_with_spaces = "/path with spaces/llama config"
os.environ["LLAMA_STACK_CONFIG_DIR"] = path_with_spaces
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(path_with_spaces))
def test_config_dirs_unicode_paths(self):
"""Test config_dirs with unicode characters in paths."""
self.clear_env_vars()
unicode_path = "/配置/llama-stack"
os.environ["LLAMA_STACK_CONFIG_DIR"] = unicode_path
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(unicode_path))
def test_config_dirs_compatibility_import(self):
"""Test that config_dirs can be imported without errors in various scenarios."""
self.clear_env_vars()
        # Test import with no environment variables
        try:
            import llama_stack.distribution.utils.config_dirs  # noqa: F401

            # If we get here without exception, the import succeeded
            self.assertTrue(True)
except Exception as e:
self.fail(f"Import failed: {e}")
def test_config_dirs_multiple_imports(self):
"""Test that multiple imports of config_dirs return consistent results."""
self.clear_env_vars()
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/consistent/path"
# First import
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_IMPORT
# Second import (should get cached result)
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_IMPORT
self.assertEqual(FIRST_IMPORT, SECOND_IMPORT)
self.assertIs(FIRST_IMPORT, SECOND_IMPORT) # Should be the same object
class TestConfigDirsIntegration(unittest.TestCase):
"""Integration tests for config_dirs with other modules."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
# Clear module cache
import sys
modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"]
for module in modules_to_clear:
if module in sys.modules:
del sys.modules[module]
def test_config_dirs_with_model_utils(self):
"""Test that config_dirs works correctly with model_utils."""
for key in self.original_env:
os.environ.pop(key, None)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
# Test that model_local_dir uses the correct base directory
model_descriptor = "meta-llama/Llama-3.2-1B-Instruct"
expected_path = str(DEFAULT_CHECKPOINT_DIR / model_descriptor.replace(":", "-"))
actual_path = model_local_dir(model_descriptor)
self.assertEqual(actual_path, expected_path)
def test_config_dirs_consistency_across_modules(self):
"""Test that all modules use consistent directory paths."""
for key in self.original_env:
os.environ.pop(key, None)
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
from llama_stack.distribution.utils.xdg_utils import (
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
)
# Paths should be consistent between modules
self.assertEqual(LLAMA_STACK_CONFIG_DIR, get_llama_stack_config_dir())
self.assertEqual(DEFAULT_CHECKPOINT_DIR.parent, get_llama_stack_data_dir())
self.assertEqual(RUNTIME_BASE_DIR.parent, get_llama_stack_state_dir())
if __name__ == "__main__":
unittest.main()
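Taken together, the suites above assert a three-level precedence for the base directory: an explicit `LLAMA_STACK_CONFIG_DIR` wins, an existing legacy `~/.llama` comes next, and the XDG trio is the default, with empty XDG variables treated like unset ones. A compact sketch of that resolution (hypothetical function, assembled from the test expectations rather than the shipped `config_dirs`):

```python
import os
from pathlib import Path


def resolve_llama_dirs() -> dict[str, Path]:
    """Resolve config/checkpoint/runtime base paths with the expected precedence."""
    override = os.environ.get("LLAMA_STACK_CONFIG_DIR")
    if override:
        base = Path(override).expanduser()
        return {"config": base, "checkpoints": base / "checkpoints", "runtime": base / "runtime"}
    legacy = Path.home() / ".llama"
    if legacy.exists():
        return {"config": legacy, "checkpoints": legacy / "checkpoints", "runtime": legacy / "runtime"}
    # `or` makes an empty XDG variable fall back to the spec default, like an unset one.
    config = Path(os.environ.get("XDG_CONFIG_HOME") or "~/.config").expanduser() / "llama-stack"
    data = Path(os.environ.get("XDG_DATA_HOME") or "~/.local/share").expanduser() / "llama-stack"
    state = Path(os.environ.get("XDG_STATE_HOME") or "~/.local/state").expanduser() / "llama-stack"
    return {"config": config, "checkpoints": data / "checkpoints", "runtime": state / "runtime"}
```

Note how only the legacy modes keep everything under one base; pure XDG mode splits config, data, and state across three trees, which is why the tests check `DEFAULT_CHECKPOINT_DIR` and `RUNTIME_BASE_DIR` separately from `LLAMA_STACK_CONFIG_DIR`.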


@ -0,0 +1,422 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
import yaml
# Template imports will be tested through file system access
class TestTemplateXDGPaths(unittest.TestCase):
"""Test that templates use XDG-compliant paths correctly."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
self.env_vars = [
"XDG_CONFIG_HOME",
"XDG_DATA_HOME",
"XDG_STATE_HOME",
"XDG_CACHE_HOME",
"LLAMA_STACK_CONFIG_DIR",
"SQLITE_STORE_DIR",
"FILES_STORAGE_DIR",
]
for key in self.env_vars:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def clear_env_vars(self):
"""Clear all relevant environment variables."""
for key in self.env_vars:
os.environ.pop(key, None)
def test_ollama_template_run_yaml_xdg_paths(self):
"""Test that ollama template's run.yaml uses XDG environment variables."""
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
if not template_path.exists():
self.skipTest("Ollama template not found")
content = template_path.read_text()
# Check for XDG-compliant environment variable references
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
# Check that paths use llama-stack directory
self.assertIn("llama-stack", content)
# Check specific path patterns
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama", content)
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama", content)
def test_ollama_template_run_yaml_parsing(self):
"""Test that ollama template's run.yaml can be parsed correctly."""
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
if not template_path.exists():
self.skipTest("Ollama template not found")
content = template_path.read_text()
# Replace environment variables with test values for parsing
test_content = (
content.replace("${env.XDG_STATE_HOME:-~/.local/state}", "/test/state")
.replace("${env.XDG_DATA_HOME:-~/.local/share}", "/test/data")
.replace(
"${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}",
"/test/state/llama-stack/distributions/ollama",
)
)
# Should be valid YAML
try:
yaml.safe_load(test_content)
except yaml.YAMLError as e:
self.fail(f"Template YAML is invalid: {e}")
def test_template_environment_variable_expansion(self):
"""Test environment variable expansion in templates."""
self.clear_env_vars()
# Set XDG variables
os.environ["XDG_STATE_HOME"] = "/custom/state"
os.environ["XDG_DATA_HOME"] = "/custom/data"
# Test pattern that should expand
pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
expected = "/custom/state/llama-stack/test"
# Mock environment variable expansion (this would normally be done by the shell)
expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", os.environ["XDG_STATE_HOME"])
self.assertEqual(expanded, expected)
def test_template_fallback_values(self):
"""Test that templates have correct fallback values."""
self.clear_env_vars()
# Test fallback pattern
pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
# When environment variable is not set, should use fallback
if "XDG_STATE_HOME" not in os.environ:
# This is what the shell would do
expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state")
expected = "~/.local/state/llama-stack/test"
self.assertEqual(expanded, expected)
def test_ollama_template_python_config_xdg(self):
"""Test that ollama template's Python config uses XDG-compliant paths."""
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py"
if not template_path.exists():
self.skipTest("Ollama template Python file not found")
content = template_path.read_text()
# Check for XDG-compliant environment variable references
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
# Check that paths use llama-stack directory
self.assertIn("llama-stack", content)
def test_template_path_consistency(self):
"""Test that template paths are consistent across different files."""
ollama_yaml_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
ollama_py_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py"
if not ollama_yaml_path.exists() or not ollama_py_path.exists():
self.skipTest("Ollama template files not found")
yaml_content = ollama_yaml_path.read_text()
py_content = ollama_py_path.read_text()
# Both should use the same XDG environment variable patterns
xdg_patterns = ["${env.XDG_STATE_HOME:-~/.local/state}", "${env.XDG_DATA_HOME:-~/.local/share}", "llama-stack"]
for pattern in xdg_patterns:
self.assertIn(pattern, yaml_content, f"Pattern {pattern} not found in YAML")
self.assertIn(pattern, py_content, f"Pattern {pattern} not found in Python")
def test_template_no_hardcoded_legacy_paths(self):
"""Test that templates don't contain hardcoded legacy paths."""
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
if not template_dir.exists():
self.skipTest("Templates directory not found")
# Check various template files
for template_path in template_dir.rglob("*.yaml"):
content = template_path.read_text()
# Should not contain hardcoded ~/.llama paths
self.assertNotIn("~/.llama", content, f"Found hardcoded ~/.llama in {template_path}")
# Should not contain hardcoded /tmp paths for persistent data
if "db_path" in content or "storage_dir" in content:
self.assertNotIn("/tmp", content, f"Found hardcoded /tmp in {template_path}")
def test_template_environment_variable_format(self):
"""Test that templates use correct environment variable format."""
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
if not template_dir.exists():
self.skipTest("Templates directory not found")
# Pattern for XDG environment variables with fallbacks
xdg_patterns = [
"${env.XDG_CONFIG_HOME:-~/.config}",
"${env.XDG_DATA_HOME:-~/.local/share}",
"${env.XDG_STATE_HOME:-~/.local/state}",
"${env.XDG_CACHE_HOME:-~/.cache}",
]
for template_path in template_dir.rglob("*.yaml"):
content = template_path.read_text()
# If XDG variables are used, they should have proper fallbacks
for pattern in xdg_patterns:
base_var = pattern.split(":-")[0] + "}"
if base_var in content:
self.assertIn(pattern, content, f"XDG variable without fallback in {template_path}")
def test_template_sqlite_store_dir_xdg(self):
"""Test that SQLITE_STORE_DIR uses XDG-compliant fallback."""
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
if not template_dir.exists():
self.skipTest("Templates directory not found")
for template_path in template_dir.rglob("*.yaml"):
content = template_path.read_text()
if "SQLITE_STORE_DIR" in content:
# Should use XDG fallback pattern
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
self.assertIn("llama-stack", content)
def test_template_files_storage_dir_xdg(self):
"""Test that FILES_STORAGE_DIR uses XDG-compliant fallback."""
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
if not template_dir.exists():
self.skipTest("Templates directory not found")
for template_path in template_dir.rglob("*.yaml"):
content = template_path.read_text()
if "FILES_STORAGE_DIR" in content:
# Should use XDG fallback pattern
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
self.assertIn("llama-stack", content)
class TestTemplateCodeGeneration(unittest.TestCase):
"""Test template code generation with XDG paths."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def test_provider_codegen_xdg_paths(self):
"""Test that provider code generation uses XDG-compliant paths."""
codegen_path = Path(__file__).parent.parent.parent / "scripts" / "provider_codegen.py"
if not codegen_path.exists():
self.skipTest("Provider codegen script not found")
content = codegen_path.read_text()
# Should use XDG-compliant path in documentation
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack", content)
# Should not use hardcoded ~/.llama paths
self.assertNotIn("~/.llama/dummy", content)
def test_template_sample_config_paths(self):
"""Test that template sample configs use XDG-compliant paths."""
# This test checks that when templates generate sample configs,
# they use XDG-compliant paths
# Mock a template that generates sample config
with patch("llama_stack.templates.template.Template") as mock_template:
mock_instance = MagicMock()
mock_template.return_value = mock_instance
# Mock sample config generation
def mock_sample_config(distro_dir):
# Should use XDG-compliant path structure
self.assertIn("llama-stack", distro_dir)
return {"config": "test"}
mock_instance.sample_run_config = mock_sample_config
# Test sample config generation
template = mock_template()
template.sample_run_config("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/test")
def test_template_path_substitution(self):
"""Test that template path substitution works correctly."""
# Test path substitution in template generation
original_path = "~/.llama/distributions/test"
# Should be converted to XDG-compliant path
xdg_path = original_path.replace("~/.llama", "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack")
expected = "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/test"
self.assertEqual(xdg_path, expected)
def test_template_environment_variable_precedence(self):
"""Test environment variable precedence in templates."""
# Test that custom XDG variables take precedence over defaults
test_cases = [
{
"env": {"XDG_STATE_HOME": "/custom/state"},
"pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test",
"expected": "/custom/state/llama-stack/test",
},
{
"env": {}, # No XDG variable set
"pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test",
"expected": "~/.local/state/llama-stack/test",
},
]
for case in test_cases:
# Clear environment
for key in ["XDG_STATE_HOME", "XDG_DATA_HOME", "XDG_CONFIG_HOME"]:
os.environ.pop(key, None)
# Set test environment
for key, value in case["env"].items():
os.environ[key] = value
# Simulate shell variable expansion
pattern = case["pattern"]
for key, value in case["env"].items():
var_pattern = f"${{env.{key}:-"
if var_pattern in pattern:
# Replace with actual value
pattern = pattern.replace(f"${{env.{key}:-~/.local/state}}", value)
# If no replacement happened, use fallback
if "${env.XDG_STATE_HOME:-~/.local/state}" in pattern:
pattern = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state")
self.assertEqual(pattern, case["expected"])
class TestTemplateIntegration(unittest.TestCase):
"""Integration tests for templates with XDG compliance."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def test_template_with_xdg_environment(self):
"""Test template behavior with XDG environment variables set."""
# Clear environment
for key in self.original_env:
os.environ.pop(key, None)
# Set custom XDG variables
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
os.environ["XDG_DATA_HOME"] = "/custom/data"
os.environ["XDG_STATE_HOME"] = "/custom/state"
# Test that template paths would resolve correctly
# (This is a conceptual test since actual shell expansion happens at runtime)
template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
# In a real shell, this would expand to: /custom/state/llama-stack/test
# Verify the pattern structure is correct
self.assertIn("XDG_STATE_HOME", template_pattern)
self.assertIn("llama-stack", template_pattern)
self.assertIn("~/.local/state", template_pattern) # fallback
def test_template_with_no_xdg_environment(self):
"""Test template behavior with no XDG environment variables."""
# Clear all XDG environment variables
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"]:
os.environ.pop(key, None)
# Test that templates would use fallback values
template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
# In a real shell with no XDG_STATE_HOME, this would expand to: ~/.local/state/llama-stack/test
# Verify the pattern structure includes fallback
self.assertIn(":-~/.local/state", template_pattern)
def test_template_consistency_across_providers(self):
"""Test that all template providers use consistent XDG patterns."""
templates_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
if not templates_dir.exists():
self.skipTest("Templates directory not found")
# Collect run.yaml files from a few different provider templates and check
# that they use consistent XDG patterns
provider_templates = []
for provider_dir in templates_dir.iterdir():
if provider_dir.is_dir() and not provider_dir.name.startswith("."):
run_yaml = provider_dir / "run.yaml"
if run_yaml.exists():
provider_templates.append(run_yaml)
if not provider_templates:
self.skipTest("No provider templates found")
# Check that templates use consistent patterns
for template_path in provider_templates[:3]: # Check first 3 templates
content = template_path.read_text()
# Should use llama-stack in paths
if any(xdg_var in content for xdg_var in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME"]):
self.assertIn("llama-stack", content, f"Template {template_path} uses XDG but not llama-stack")
if __name__ == "__main__":
unittest.main()
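Several of the tests above fake the shell's `${env.VAR:-default}` expansion with hard-coded `str.replace` calls. A small generic expander makes the semantics they assume explicit; this is a hypothetical illustration, not part of llama_stack:

```python
import os
import re


def expand_env_pattern(pattern, env=None):
    """Expand ``${env.VAR:-default}`` placeholders the way a POSIX shell
    expands ``${VAR:-default}``: use the variable's value when it is set
    and non-empty, otherwise the fallback. Illustrative sketch only."""
    env = os.environ if env is None else env

    def repl(match):
        value = env.get(match.group(1), "")
        return value if value else match.group(2)

    return re.sub(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):-([^}]*)\}", repl, pattern)
```

With this helper, the precedence cases in `test_template_environment_variable_precedence` reduce to calling `expand_env_pattern(pattern, case["env"])` instead of chained replaces.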


@ -0,0 +1,610 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from llama_stack.distribution.utils.xdg_utils import (
ensure_directory_exists,
get_llama_stack_cache_dir,
get_llama_stack_config_dir,
get_llama_stack_data_dir,
get_llama_stack_state_dir,
get_xdg_cache_home,
get_xdg_compliant_path,
get_xdg_config_home,
get_xdg_data_home,
get_xdg_state_home,
migrate_legacy_directory,
)
class TestXDGCompliance(unittest.TestCase):
"""Comprehensive test suite for XDG Base Directory Specification compliance."""
def setUp(self):
"""Set up test environment."""
# Store original environment variables
self.original_env = {}
self.xdg_vars = ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME"]
self.llama_vars = ["LLAMA_STACK_CONFIG_DIR", "SQLITE_STORE_DIR", "FILES_STORAGE_DIR"]
for key in self.xdg_vars + self.llama_vars:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
# Restore original environment variables
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def clear_env_vars(self, vars_to_clear=None):
"""Helper to clear environment variables."""
if vars_to_clear is None:
vars_to_clear = self.xdg_vars + self.llama_vars
for key in vars_to_clear:
os.environ.pop(key, None)
def test_xdg_defaults(self):
"""Test that XDG directories use correct defaults when no env vars are set."""
self.clear_env_vars()
home = Path.home()
self.assertEqual(get_xdg_config_home(), home / ".config")
self.assertEqual(get_xdg_data_home(), home / ".local" / "share")
self.assertEqual(get_xdg_cache_home(), home / ".cache")
self.assertEqual(get_xdg_state_home(), home / ".local" / "state")
def test_xdg_custom_paths(self):
"""Test that custom XDG paths are respected."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
os.environ["XDG_DATA_HOME"] = str(temp_path / "data")
os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache")
os.environ["XDG_STATE_HOME"] = str(temp_path / "state")
self.assertEqual(get_xdg_config_home(), temp_path / "config")
self.assertEqual(get_xdg_data_home(), temp_path / "data")
self.assertEqual(get_xdg_cache_home(), temp_path / "cache")
self.assertEqual(get_xdg_state_home(), temp_path / "state")
def test_xdg_paths_with_tilde(self):
"""Test XDG paths that use tilde expansion."""
os.environ["XDG_CONFIG_HOME"] = "~/custom_config"
os.environ["XDG_DATA_HOME"] = "~/custom_data"
home = Path.home()
self.assertEqual(get_xdg_config_home(), home / "custom_config")
self.assertEqual(get_xdg_data_home(), home / "custom_data")
def test_xdg_paths_relative(self):
"""Test that relative XDG paths are resolved to absolute paths."""
original_cwd = os.getcwd()
try:
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
os.environ["XDG_CONFIG_HOME"] = "relative_config"
# Should resolve relative to the current working directory
result = get_xdg_config_home()
self.assertTrue(result.is_absolute())
self.assertTrue(str(result).endswith("relative_config"))
finally:
# Restore the working directory so later tests (and TemporaryDirectory
# cleanup on some platforms) are not affected
os.chdir(original_cwd)
def test_llama_stack_directories_new_installation(self):
"""Test llama-stack directories for new installations (no legacy directory)."""
self.clear_env_vars()
home = Path.home()
# Mock that ~/.llama doesn't exist
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack")
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
self.assertEqual(get_llama_stack_cache_dir(), home / ".cache" / "llama-stack")
def test_llama_stack_directories_with_custom_xdg(self):
"""Test llama-stack directories with custom XDG paths."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
os.environ["XDG_DATA_HOME"] = str(temp_path / "data")
os.environ["XDG_STATE_HOME"] = str(temp_path / "state")
os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache")
# Mock that ~/.llama doesn't exist
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack")
self.assertEqual(get_llama_stack_data_dir(), temp_path / "data" / "llama-stack")
self.assertEqual(get_llama_stack_state_dir(), temp_path / "state" / "llama-stack")
self.assertEqual(get_llama_stack_cache_dir(), temp_path / "cache" / "llama-stack")
def test_legacy_environment_variable_precedence(self):
"""Test that LLAMA_STACK_CONFIG_DIR takes highest precedence."""
with tempfile.TemporaryDirectory() as temp_dir:
legacy_path = Path(temp_dir) / "legacy"
xdg_path = Path(temp_dir) / "xdg"
# Set both legacy and XDG variables
os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_path)
os.environ["XDG_CONFIG_HOME"] = str(xdg_path / "config")
os.environ["XDG_DATA_HOME"] = str(xdg_path / "data")
os.environ["XDG_STATE_HOME"] = str(xdg_path / "state")
# Legacy should take precedence for all directory types
self.assertEqual(get_llama_stack_config_dir(), legacy_path)
self.assertEqual(get_llama_stack_data_dir(), legacy_path)
self.assertEqual(get_llama_stack_state_dir(), legacy_path)
self.assertEqual(get_llama_stack_cache_dir(), legacy_path)
def test_legacy_directory_exists_and_has_content(self):
"""Test that existing ~/.llama directory with content is used."""
with tempfile.TemporaryDirectory() as temp_dir:
home = Path(temp_dir)
legacy_llama = home / ".llama"
legacy_llama.mkdir()
# Create some content to simulate existing data
(legacy_llama / "test_file").touch()
(legacy_llama / "distributions").mkdir()
# Clear environment variables
self.clear_env_vars()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
self.assertEqual(get_llama_stack_config_dir(), legacy_llama)
self.assertEqual(get_llama_stack_data_dir(), legacy_llama)
self.assertEqual(get_llama_stack_state_dir(), legacy_llama)
def test_legacy_directory_exists_but_empty(self):
"""Test that empty ~/.llama directory is ignored in favor of XDG."""
with tempfile.TemporaryDirectory() as temp_dir:
home = Path(temp_dir)
legacy_llama = home / ".llama"
legacy_llama.mkdir()
# Don't add any content - directory is empty
self.clear_env_vars()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
# Should use XDG paths since legacy directory is empty
self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack")
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
def test_xdg_compliant_path_function(self):
"""Test the get_xdg_compliant_path utility function."""
self.clear_env_vars()
home = Path.home()
# Mock that ~/.llama doesn't exist
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
self.assertEqual(get_xdg_compliant_path("config"), home / ".config" / "llama-stack")
self.assertEqual(
get_xdg_compliant_path("data", "models"), home / ".local" / "share" / "llama-stack" / "models"
)
self.assertEqual(
get_xdg_compliant_path("state", "runtime"), home / ".local" / "state" / "llama-stack" / "runtime"
)
self.assertEqual(get_xdg_compliant_path("cache", "temp"), home / ".cache" / "llama-stack" / "temp")
def test_xdg_compliant_path_invalid_type(self):
"""Test that invalid path types raise ValueError."""
with self.assertRaises(ValueError) as context:
get_xdg_compliant_path("invalid_type")
self.assertIn("Unknown path type", str(context.exception))
self.assertIn("invalid_type", str(context.exception))
def test_xdg_compliant_path_with_subdirectory(self):
"""Test get_xdg_compliant_path with various subdirectories."""
self.clear_env_vars()
home = Path.home()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Test nested subdirectories
self.assertEqual(
get_xdg_compliant_path("data", "models/checkpoints"),
home / ".local" / "share" / "llama-stack" / "models/checkpoints",
)
# Test with Path object
self.assertEqual(
get_xdg_compliant_path("config", str(Path("distributions") / "ollama")),
home / ".config" / "llama-stack" / "distributions" / "ollama",
)
def test_ensure_directory_exists(self):
"""Test the ensure_directory_exists utility function."""
with tempfile.TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir) / "nested" / "directory" / "structure"
# Directory shouldn't exist initially
self.assertFalse(test_path.exists())
# Create it
ensure_directory_exists(test_path)
# Should exist now
self.assertTrue(test_path.exists())
self.assertTrue(test_path.is_dir())
def test_ensure_directory_exists_already_exists(self):
"""Test ensure_directory_exists when directory already exists."""
with tempfile.TemporaryDirectory() as temp_dir:
test_path = Path(temp_dir) / "existing"
test_path.mkdir()
# Should not raise an error
ensure_directory_exists(test_path)
self.assertTrue(test_path.exists())
def test_config_dirs_import_and_types(self):
"""Test that the config_dirs module imports correctly and has proper types."""
from llama_stack.distribution.utils.config_dirs import (
DEFAULT_CHECKPOINT_DIR,
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
RUNTIME_BASE_DIR,
)
# All should be Path objects
self.assertIsInstance(LLAMA_STACK_CONFIG_DIR, Path)
self.assertIsInstance(DEFAULT_CHECKPOINT_DIR, Path)
self.assertIsInstance(RUNTIME_BASE_DIR, Path)
self.assertIsInstance(EXTERNAL_PROVIDERS_DIR, Path)
self.assertIsInstance(DISTRIBS_BASE_DIR, Path)
# All should be absolute paths
self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute())
self.assertTrue(DEFAULT_CHECKPOINT_DIR.is_absolute())
self.assertTrue(RUNTIME_BASE_DIR.is_absolute())
self.assertTrue(EXTERNAL_PROVIDERS_DIR.is_absolute())
self.assertTrue(DISTRIBS_BASE_DIR.is_absolute())
def test_config_dirs_proper_structure(self):
"""Test that config_dirs uses proper XDG structure."""
from llama_stack.distribution.utils.config_dirs import (
DISTRIBS_BASE_DIR,
EXTERNAL_PROVIDERS_DIR,
LLAMA_STACK_CONFIG_DIR,
)
# Check that paths contain expected components
config_str = str(LLAMA_STACK_CONFIG_DIR)
self.assertTrue(
"llama-stack" in config_str or ".llama" in config_str,
f"Config dir should contain 'llama-stack' or '.llama': {config_str}",
)
# Test relationships between directories
self.assertEqual(DISTRIBS_BASE_DIR, LLAMA_STACK_CONFIG_DIR / "distributions")
self.assertEqual(EXTERNAL_PROVIDERS_DIR, LLAMA_STACK_CONFIG_DIR / "providers.d")
def test_environment_variable_combinations(self):
"""Test various combinations of environment variables."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Test partial XDG variables
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
# Leave others as default
self.clear_env_vars(["XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"])
home = Path.home()
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack")
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
def test_migrate_legacy_directory_no_legacy(self):
"""Test migration when no legacy directory exists."""
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = Path("/fake/home")
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# Should return True (success) when no migration needed
result = migrate_legacy_directory()
self.assertTrue(result)
def test_migrate_legacy_directory_exists(self):
"""Test migration message when legacy directory exists."""
with tempfile.TemporaryDirectory() as temp_dir:
home = Path(temp_dir)
legacy_llama = home / ".llama"
legacy_llama.mkdir()
(legacy_llama / "test_file").touch()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
with patch("builtins.print") as mock_print:
result = migrate_legacy_directory()
self.assertTrue(result)
# Check that migration information was printed
print_calls = [call[0][0] for call in mock_print.call_args_list]
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
self.assertTrue(any("Consider migrating" in call for call in print_calls))
def test_path_consistency_across_functions(self):
"""Test that all path functions return consistent results."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
home = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
# All config-related functions should return the same base
config_dir = get_llama_stack_config_dir()
config_path = get_xdg_compliant_path("config")
self.assertEqual(config_dir, config_path)
# All data-related functions should return the same base
data_dir = get_llama_stack_data_dir()
data_path = get_xdg_compliant_path("data")
self.assertEqual(data_dir, data_path)
def test_unicode_and_special_characters(self):
"""Test XDG paths with unicode and special characters."""
with tempfile.TemporaryDirectory() as temp_dir:
# Test with unicode characters
unicode_path = Path(temp_dir) / "配置" / "llama-stack"
os.environ["XDG_CONFIG_HOME"] = str(unicode_path.parent)
result = get_xdg_config_home()
self.assertEqual(result, unicode_path.parent)
# Test spaces in paths
space_path = Path(temp_dir) / "my config"
os.environ["XDG_CONFIG_HOME"] = str(space_path)
result = get_xdg_config_home()
self.assertEqual(result, space_path)
def test_concurrent_access_safety(self):
"""Test that XDG functions are safe for concurrent access."""
import threading
import time
results = []
errors = []
def worker():
try:
# Simulate concurrent access
config_dir = get_llama_stack_config_dir()
time.sleep(0.01) # Small delay to increase chance of race conditions
data_dir = get_llama_stack_data_dir()
results.append((config_dir, data_dir))
except Exception as e:
errors.append(e)
# Start multiple threads
threads = []
for _ in range(10):
t = threading.Thread(target=worker)
threads.append(t)
t.start()
# Wait for all threads
for t in threads:
t.join()
# Check results
self.assertEqual(len(errors), 0, f"Concurrent access errors: {errors}")
self.assertEqual(len(results), 10)
# All results should be identical
first_result = results[0]
for result in results[1:]:
self.assertEqual(result, first_result)
def test_symlink_handling(self):
"""Test XDG path handling with symbolic links."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create actual directory
actual_dir = temp_path / "actual_config"
actual_dir.mkdir()
# Create symlink
symlink_dir = temp_path / "symlinked_config"
symlink_dir.symlink_to(actual_dir)
os.environ["XDG_CONFIG_HOME"] = str(symlink_dir)
result = get_xdg_config_home()
self.assertEqual(result, symlink_dir)
# Should resolve to actual path when needed
self.assertTrue(result.exists())
def test_readonly_directory_handling(self):
"""Test behavior when XDG directories are read-only."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
readonly_dir = temp_path / "readonly"
readonly_dir.mkdir()
# Make directory read-only
readonly_dir.chmod(0o444)
try:
os.environ["XDG_CONFIG_HOME"] = str(readonly_dir)
# Should still return the path even if read-only
result = get_xdg_config_home()
self.assertEqual(result, readonly_dir)
finally:
# Restore permissions for cleanup
readonly_dir.chmod(0o755)
def test_nonexistent_parent_directory(self):
"""Test XDG paths with non-existent parent directories."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use a path with non-existent parents
nonexistent_path = Path(temp_dir) / "does" / "not" / "exist" / "config"
os.environ["XDG_CONFIG_HOME"] = str(nonexistent_path)
# Should return the path even if it doesn't exist
result = get_xdg_config_home()
self.assertEqual(result, nonexistent_path)
def test_env_var_expansion(self):
"""Test environment variable expansion in XDG paths."""
with tempfile.TemporaryDirectory() as temp_dir:
os.environ["TEST_BASE"] = temp_dir
os.environ["XDG_CONFIG_HOME"] = "$TEST_BASE/config"
# Path expansion should work
result = get_xdg_config_home()
expected = Path(temp_dir) / "config"
self.assertEqual(result, expected)
class TestXDGEdgeCases(unittest.TestCase):
"""Test edge cases and error conditions for XDG compliance."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def test_empty_environment_variables(self):
"""Test behavior with empty environment variables."""
# Set empty values
os.environ["XDG_CONFIG_HOME"] = ""
os.environ["XDG_DATA_HOME"] = ""
# Should fall back to defaults
home = Path.home()
self.assertEqual(get_xdg_config_home(), home / ".config")
self.assertEqual(get_xdg_data_home(), home / ".local" / "share")
def test_whitespace_only_environment_variables(self):
"""Test behavior with whitespace-only environment variables."""
os.environ["XDG_CONFIG_HOME"] = " "
os.environ["XDG_DATA_HOME"] = "\t\n"
# Should handle whitespace gracefully
result_config = get_xdg_config_home()
result_data = get_xdg_data_home()
# Results should be valid Path objects
self.assertIsInstance(result_config, Path)
self.assertIsInstance(result_data, Path)
def test_very_long_paths(self):
"""Test behavior with very long directory paths."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create a very long path
long_path_parts = ["very_long_directory_name_" + str(i) for i in range(20)]
long_path = Path(temp_dir)
for part in long_path_parts:
long_path = long_path / part
os.environ["XDG_CONFIG_HOME"] = str(long_path)
result = get_xdg_config_home()
self.assertEqual(result, long_path)
def test_circular_symlinks(self):
"""Test handling of circular symbolic links."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create circular symlinks
link1 = temp_path / "link1"
link2 = temp_path / "link2"
try:
link1.symlink_to(link2)
link2.symlink_to(link1)
os.environ["XDG_CONFIG_HOME"] = str(link1)
# Should handle circular symlinks gracefully
result = get_xdg_config_home()
self.assertEqual(result, link1)
except (OSError, NotImplementedError):
# Some systems don't support circular symlinks
self.skipTest("System doesn't support circular symlinks")
    def test_permission_denied_scenarios(self):
        """Test scenarios where permission is denied."""
        # This test is platform-specific; constructing the path does not
        # require access to it, so no exception handling is needed here.
        # Use a system directory that typically requires root
        os.environ["XDG_CONFIG_HOME"] = "/root/.config"
        # Should still return the path even if we can't access it
        result = get_xdg_config_home()
        self.assertEqual(result, Path("/root/.config"))
def test_network_paths(self):
"""Test XDG paths with network/UNC paths (Windows-style)."""
# Test UNC path (though this may not work on non-Windows systems)
network_path = "//server/share/config"
os.environ["XDG_CONFIG_HOME"] = network_path
result = get_xdg_config_home()
# Should handle network paths gracefully
self.assertIsInstance(result, Path)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,522 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import platform
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from llama_stack.distribution.utils.xdg_utils import (
ensure_directory_exists,
get_llama_stack_config_dir,
get_xdg_config_home,
migrate_legacy_directory,
)
class TestXDGEdgeCases(unittest.TestCase):
"""Test edge cases and error conditions for XDG compliance."""
def setUp(self):
"""Set up test environment."""
self.original_env = {}
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
self.original_env[key] = os.environ.get(key)
def tearDown(self):
"""Clean up test environment."""
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value
def clear_env_vars(self):
"""Clear all XDG environment variables."""
for key in self.original_env:
os.environ.pop(key, None)
def test_very_long_paths(self):
"""Test XDG functions with very long directory paths."""
self.clear_env_vars()
# Create a very long path (close to filesystem limits)
long_components = ["very_long_directory_name_" + str(i) for i in range(50)]
long_path = "/tmp/" + "/".join(long_components)
# Test with very long XDG paths
os.environ["XDG_CONFIG_HOME"] = long_path
result = get_xdg_config_home()
self.assertEqual(result, Path(long_path))
# Should handle long paths in llama-stack functions
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
config_dir = get_llama_stack_config_dir()
self.assertEqual(config_dir, Path(long_path) / "llama-stack")
def test_paths_with_special_characters(self):
"""Test XDG functions with special characters in paths."""
self.clear_env_vars()
# Test various special characters
special_chars = [
"path with spaces",
"path-with-hyphens",
"path_with_underscores",
"path.with.dots",
"path@with@symbols",
"path+with+plus",
"path&with&ampersand",
"path(with)parentheses",
]
for special_path in special_chars:
with self.subTest(path=special_path):
test_path = f"/tmp/{special_path}"
os.environ["XDG_CONFIG_HOME"] = test_path
result = get_xdg_config_home()
self.assertEqual(result, Path(test_path))
def test_unicode_paths(self):
"""Test XDG functions with unicode characters in paths."""
self.clear_env_vars()
unicode_paths = [
"/配置/llama-stack", # Chinese
"/конфигурация/llama-stack", # Russian
"/構成/llama-stack", # Japanese
"/구성/llama-stack", # Korean
"/تكوين/llama-stack", # Arabic
"/configuración/llama-stack", # Spanish with accents
"/配置📁/llama-stack", # With emoji
]
for unicode_path in unicode_paths:
with self.subTest(path=unicode_path):
os.environ["XDG_CONFIG_HOME"] = unicode_path
result = get_xdg_config_home()
self.assertEqual(result, Path(unicode_path))
def test_network_paths(self):
"""Test XDG functions with network/UNC paths."""
self.clear_env_vars()
if platform.system() == "Windows":
# Test Windows UNC paths
unc_paths = [
"\\\\server\\share\\config",
"\\\\server.domain.com\\share\\config",
"\\\\192.168.1.100\\config",
]
for unc_path in unc_paths:
with self.subTest(path=unc_path):
os.environ["XDG_CONFIG_HOME"] = unc_path
result = get_xdg_config_home()
self.assertEqual(result, Path(unc_path))
else:
# Test network mount paths on Unix-like systems
network_paths = [
"/mnt/nfs/config",
"/net/server/config",
"/media/network/config",
]
for network_path in network_paths:
with self.subTest(path=network_path):
os.environ["XDG_CONFIG_HOME"] = network_path
result = get_xdg_config_home()
self.assertEqual(result, Path(network_path))
def test_nonexistent_paths(self):
"""Test XDG functions with non-existent paths."""
self.clear_env_vars()
nonexistent_path = "/this/path/does/not/exist/config"
os.environ["XDG_CONFIG_HOME"] = nonexistent_path
# Should return the path even if it doesn't exist
result = get_xdg_config_home()
self.assertEqual(result, Path(nonexistent_path))
# Should work with llama-stack functions too
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
mock_exists.return_value = False
config_dir = get_llama_stack_config_dir()
self.assertEqual(config_dir, Path(nonexistent_path) / "llama-stack")
def test_circular_symlinks(self):
"""Test XDG functions with circular symbolic links."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create circular symlinks
link1 = temp_path / "link1"
link2 = temp_path / "link2"
try:
link1.symlink_to(link2)
link2.symlink_to(link1)
os.environ["XDG_CONFIG_HOME"] = str(link1)
# Should handle circular symlinks gracefully
result = get_xdg_config_home()
self.assertEqual(result, link1)
except (OSError, NotImplementedError):
# Some systems don't support circular symlinks
self.skipTest("System doesn't support circular symlinks")
def test_broken_symlinks(self):
"""Test XDG functions with broken symbolic links."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Create broken symlink
target = temp_path / "nonexistent_target"
link = temp_path / "broken_link"
try:
link.symlink_to(target)
os.environ["XDG_CONFIG_HOME"] = str(link)
# Should handle broken symlinks gracefully
result = get_xdg_config_home()
self.assertEqual(result, link)
except (OSError, NotImplementedError):
# Some systems might not support this
self.skipTest("System doesn't support broken symlinks")
def test_readonly_directories(self):
"""Test XDG functions with read-only directories."""
self.clear_env_vars()
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
readonly_dir = temp_path / "readonly"
readonly_dir.mkdir()
# Make directory read-only
readonly_dir.chmod(0o444)
try:
os.environ["XDG_CONFIG_HOME"] = str(readonly_dir)
# Should still return the path
result = get_xdg_config_home()
self.assertEqual(result, readonly_dir)
finally:
# Restore permissions for cleanup
readonly_dir.chmod(0o755)
def test_permission_denied_access(self):
"""Test XDG functions when permission is denied."""
self.clear_env_vars()
# This test is platform-specific
if platform.system() != "Windows":
# Try to use a system directory that typically requires root
restricted_paths = [
"/root/.config",
"/etc/config",
"/var/root/config",
]
for restricted_path in restricted_paths:
with self.subTest(path=restricted_path):
os.environ["XDG_CONFIG_HOME"] = restricted_path
# Should still return the path even if we can't access it
result = get_xdg_config_home()
self.assertEqual(result, Path(restricted_path))
def test_environment_variable_injection(self):
"""Test XDG functions with environment variable injection attempts."""
self.clear_env_vars()
# Test potential injection attempts
injection_attempts = [
"/tmp/config; rm -rf /",
"/tmp/config && echo 'injected'",
"/tmp/config | cat /etc/passwd",
"/tmp/config`whoami`",
"/tmp/config$(whoami)",
"/tmp/config\necho 'newline'",
]
for injection_attempt in injection_attempts:
with self.subTest(attempt=injection_attempt):
os.environ["XDG_CONFIG_HOME"] = injection_attempt
# Should treat as literal path, not execute
result = get_xdg_config_home()
self.assertEqual(result, Path(injection_attempt))
def test_extremely_nested_paths(self):
"""Test XDG functions with extremely nested directory structures."""
self.clear_env_vars()
# Create deeply nested path
nested_parts = ["level" + str(i) for i in range(100)]
nested_path = "/tmp/" + "/".join(nested_parts)
os.environ["XDG_CONFIG_HOME"] = nested_path
result = get_xdg_config_home()
self.assertEqual(result, Path(nested_path))
def test_empty_and_whitespace_paths(self):
"""Test XDG functions with empty and whitespace-only paths."""
self.clear_env_vars()
empty_values = [
"",
" ",
"\t",
"\n",
"\r\n",
" \t \n ",
]
for empty_value in empty_values:
with self.subTest(value=repr(empty_value)):
os.environ["XDG_CONFIG_HOME"] = empty_value
# Should fall back to default
result = get_xdg_config_home()
self.assertEqual(result, Path.home() / ".config")
    def test_path_with_null_bytes(self):
        """Test XDG functions with null bytes in paths."""
        self.clear_env_vars()
        # Test path with a null byte. Note that os.environ itself rejects
        # embedded null bytes with ValueError, so the assignment must be
        # inside the try block.
        null_path = "/tmp/config\x00/test"
        try:
            os.environ["XDG_CONFIG_HOME"] = null_path
            result = get_xdg_config_home()
            # If no error was raised, check the result
            self.assertIsInstance(result, Path)
        except (ValueError, OSError):
            # This is expected behavior for null bytes
            pass
def test_concurrent_access_safety(self):
"""Test that XDG functions are thread-safe."""
import threading
import time
self.clear_env_vars()
results = []
errors = []
def worker(thread_id):
try:
# Each thread sets a different XDG path
os.environ["XDG_CONFIG_HOME"] = f"/tmp/thread_{thread_id}"
# Small delay to increase chance of race conditions
time.sleep(0.01)
config_dir = get_llama_stack_config_dir()
results.append((thread_id, config_dir))
except Exception as e:
errors.append((thread_id, e))
# Start multiple threads
threads = []
for i in range(20):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()
# Wait for all threads
for t in threads:
t.join()
# Check for errors
if errors:
self.fail(f"Thread errors: {errors}")
# Check that we got results from all threads
self.assertEqual(len(results), 20)
def test_filesystem_limits(self):
"""Test XDG functions approaching filesystem limits."""
self.clear_env_vars()
# Test with very long filename (close to 255 char limit)
long_filename = "a" * 240
long_path = f"/tmp/{long_filename}"
os.environ["XDG_CONFIG_HOME"] = long_path
result = get_xdg_config_home()
self.assertEqual(result, Path(long_path))
def test_case_sensitivity(self):
"""Test XDG functions with case sensitivity edge cases."""
self.clear_env_vars()
# Test case variations
case_variations = [
"/tmp/Config",
"/tmp/CONFIG",
"/tmp/config",
"/tmp/Config/MixedCase",
]
for case_path in case_variations:
with self.subTest(path=case_path):
os.environ["XDG_CONFIG_HOME"] = case_path
result = get_xdg_config_home()
self.assertEqual(result, Path(case_path))
def test_ensure_directory_exists_edge_cases(self):
"""Test ensure_directory_exists with edge cases."""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Test with file that exists but is not a directory
file_path = temp_path / "file_not_dir"
file_path.touch()
with self.assertRaises(FileExistsError):
ensure_directory_exists(file_path)
            # Test with permission denied. Skip when running as root, since
            # root bypasses file-mode checks and no PermissionError is raised.
            if platform.system() != "Windows" and os.geteuid() != 0:
readonly_parent = temp_path / "readonly_parent"
readonly_parent.mkdir()
readonly_parent.chmod(0o444)
try:
nested_path = readonly_parent / "nested"
with self.assertRaises(PermissionError):
ensure_directory_exists(nested_path)
finally:
# Restore permissions for cleanup
readonly_parent.chmod(0o755)
def test_migrate_legacy_directory_edge_cases(self):
"""Test migrate_legacy_directory with edge cases."""
with tempfile.TemporaryDirectory() as temp_dir:
home_dir = Path(temp_dir)
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = home_dir
# Test with legacy directory but no write permissions
legacy_dir = home_dir / ".llama"
legacy_dir.mkdir()
(legacy_dir / "test_file").touch()
# Make home directory read-only
home_dir.chmod(0o444)
try:
# Should handle permission errors gracefully
with patch("builtins.print") as mock_print:
migrate_legacy_directory()
# Should print some information
self.assertTrue(mock_print.called)
finally:
# Restore permissions for cleanup
home_dir.chmod(0o755)
legacy_dir.chmod(0o755)
def test_path_traversal_attempts(self):
"""Test XDG functions with path traversal attempts."""
self.clear_env_vars()
traversal_attempts = [
"/tmp/config/../../../etc/passwd",
"/tmp/config/../../root/.ssh",
"/tmp/config/../../../../../etc/shadow",
"/tmp/config/./../../root",
]
for traversal_attempt in traversal_attempts:
with self.subTest(attempt=traversal_attempt):
os.environ["XDG_CONFIG_HOME"] = traversal_attempt
# Should handle path traversal attempts by treating as literal paths
result = get_xdg_config_home()
self.assertEqual(result, Path(traversal_attempt))
def test_environment_variable_precedence_edge_cases(self):
"""Test environment variable precedence with edge cases."""
self.clear_env_vars()
# Test with both old and new env vars set
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/path"
os.environ["XDG_CONFIG_HOME"] = "/xdg/path"
# Create fake legacy directory
with tempfile.TemporaryDirectory() as temp_dir:
fake_home = Path(temp_dir)
fake_legacy = fake_home / ".llama"
fake_legacy.mkdir()
(fake_legacy / "test_file").touch()
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
mock_home.return_value = fake_home
# LLAMA_STACK_CONFIG_DIR should take precedence
config_dir = get_llama_stack_config_dir()
self.assertEqual(config_dir, Path("/legacy/path"))
def test_malformed_environment_variables(self):
"""Test XDG functions with malformed environment variables."""
self.clear_env_vars()
malformed_values = [
"not_an_absolute_path",
"~/tilde_not_expanded",
"$HOME/variable_not_expanded",
"relative/path/config",
"./relative/path",
"../parent/path",
]
for malformed_value in malformed_values:
with self.subTest(value=malformed_value):
os.environ["XDG_CONFIG_HOME"] = malformed_value
# Should handle malformed values gracefully
result = get_xdg_config_home()
self.assertIsInstance(result, Path)
if __name__ == "__main__":
unittest.main()