diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md new file mode 100644 index 000000000..a978a6181 --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -0,0 +1,172 @@ + +# Ollama Distribution + +The `llamastack/distribution-ollama` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| files | `inline::localfs` | +| inference | `remote::ollama` | +| post_training | `inline::huggingface` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | +| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`) +- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`) +- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`) +- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`) + + +## Prerequisites + +### Ollama Server + +This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps: + +1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/) + +2. **Start the Ollama server**: + ```bash + ollama serve + ``` + By default, Ollama serves on `http://127.0.0.1:11434` + +3. **Pull the required models**: + ```bash + # Pull the inference model + ollama pull meta-llama/Llama-3.2-3B-Instruct + + # Pull the embedding model + ollama pull all-minilm:latest + + # (Optional) Pull the safety model for run-with-safety.yaml + ollama pull meta-llama/Llama-Guard-3-1B + ``` + +## Supported Services + +### Inference: Ollama +Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable. + +### Vector IO: FAISS +Provides vector storage capabilities using FAISS for embeddings and similarity search operations. + +### Safety: Llama Guard (Optional) +When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server. + +### Agents: Meta Reference +Provides agent execution capabilities using the meta-reference implementation. + +### Post-Training: Hugging Face +Supports model fine-tuning using Hugging Face integration. + +### Tool Runtime +Supports various external tools including: +- Brave Search +- Tavily Search +- RAG Runtime +- Model Context Protocol +- Wolfram Alpha + +## Running Llama Stack with Ollama + +You can do this via Conda or venv (build code), or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. 
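+
+The command below reads `OLLAMA_URL` and `INFERENCE_MODEL` from your shell, so export them first if you have not already. A minimal sketch, assuming the default model listed above and an Ollama server running on the Docker host (on Linux you may need `--network host` or the host's IP address instead of `host.docker.internal`):
+
+```bash
+# Model to serve (matches the INFERENCE_MODEL default above)
+export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+
+# URL the container uses to reach the Ollama server on the host
+export OLLAMA_URL=http://host.docker.internal:11434
+```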
+ +```bash +LLAMA_STACK_PORT=8321 +docker run \ + -it \ + --pull always \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-ollama \ + --config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Via Conda + +```bash +llama stack build --template ollama --image-type conda +llama stack run ./run.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Via venv + +If you've set up your local development environment, you can also build the image using your local virtual environment. + +```bash +llama stack build --template ollama --image-type venv +llama stack run ./run.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Running with Safety + +To enable safety checks, use the `run-with-safety.yaml` configuration: + +```bash +llama stack run ./run-with-safety.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL +``` + +## Example Usage + +Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client: + +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url="http://localhost:8321") + +# Run inference +response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.2-3B-Instruct", + messages=[{"role": "user", "content": "Hello, how are you?"}], +) +print(response.completion_message.content) +``` + +## Troubleshooting + +### Common Issues + +1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL. + +2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull `. + +3. **Performance issues**: Consider using more powerful models or adjusting the Ollama server configuration for better performance. + +### Logs + +Check the Ollama server logs for any issues: +```bash +# Ollama logs are typically available in: +# - macOS: ~/Library/Logs/Ollama/ +# - Linux: ~/.ollama/logs/ +``` diff --git a/docs/source/getting_started/xdg_compliance.md b/docs/source/getting_started/xdg_compliance.md new file mode 100644 index 000000000..6f7a32b9c --- /dev/null +++ b/docs/source/getting_started/xdg_compliance.md @@ -0,0 +1,191 @@ +# XDG Base Directory Specification Compliance + +Starting with version 0.2.14, Llama Stack follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) for organizing configuration and data files. This provides better integration with modern desktop environments and allows for more flexible customization of where files are stored. 
+ +## Overview + +The XDG Base Directory Specification defines standard locations for different types of application data: + +- **Configuration files** (`XDG_CONFIG_HOME`): User-specific configuration files +- **Data files** (`XDG_DATA_HOME`): User-specific data files that should persist +- **Cache files** (`XDG_CACHE_HOME`): User-specific cache files +- **State files** (`XDG_STATE_HOME`): User-specific state files + +## Directory Mapping + +Llama Stack now uses the following XDG-compliant directory structure: + +| Data Type | XDG Directory | Default Location | Description | +|-----------|---------------|------------------|-------------| +| Configuration | `XDG_CONFIG_HOME` | `~/.config/llama-stack` | Distribution configs, provider configs | +| Data | `XDG_DATA_HOME` | `~/.local/share/llama-stack` | Model checkpoints, persistent files | +| Cache | `XDG_CACHE_HOME` | `~/.cache/llama-stack` | Temporary cache files | +| State | `XDG_STATE_HOME` | `~/.local/state/llama-stack` | Runtime state, databases | + +## Environment Variables + +You can customize the locations by setting these environment variables: + +```bash +# Override the base directories +export XDG_CONFIG_HOME="/custom/config/path" +export XDG_DATA_HOME="/custom/data/path" +export XDG_CACHE_HOME="/custom/cache/path" +export XDG_STATE_HOME="/custom/state/path" + +# Or override specific Llama Stack directories +export SQLITE_STORE_DIR="/custom/database/path" +export FILES_STORAGE_DIR="/custom/files/path" +``` + +## Backwards Compatibility + +Llama Stack maintains full backwards compatibility with existing installations: + +1. **Legacy Environment Variable**: If `LLAMA_STACK_CONFIG_DIR` is set, it will be used for all directories +2. **Legacy Directory Detection**: If `~/.llama` exists and contains data, it will continue to be used +3. 
**Gradual Migration**: New installations use XDG paths, existing installations continue to work + +## Migration Guide + +### Automatic Migration + +Use the built-in migration command to move from legacy `~/.llama` to XDG-compliant directories: + +```bash +# Preview what would be migrated +llama migrate-xdg --dry-run + +# Perform the migration +llama migrate-xdg +``` + +### Manual Migration + +If you prefer to migrate manually, here's the mapping: + +```bash +# Create XDG directories +mkdir -p ~/.config/llama-stack +mkdir -p ~/.local/share/llama-stack +mkdir -p ~/.local/state/llama-stack + +# Move configuration files +mv ~/.llama/distributions ~/.config/llama-stack/ +mv ~/.llama/providers.d ~/.config/llama-stack/ + +# Move data files +mv ~/.llama/checkpoints ~/.local/share/llama-stack/ + +# Move state files +mv ~/.llama/runtime ~/.local/state/llama-stack/ + +# Clean up empty legacy directory +rmdir ~/.llama +``` + +### Environment Variables in Configurations + +Template configurations now use XDG-compliant environment variables: + +```yaml +# Old format +db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db + +# New format +db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/registry.db +``` + +## Configuration Examples + +### Using Custom XDG Directories + +```bash +# Set custom XDG directories +export XDG_CONFIG_HOME="/opt/llama-stack/config" +export XDG_DATA_HOME="/opt/llama-stack/data" +export XDG_STATE_HOME="/opt/llama-stack/state" + +# Start Llama Stack +llama stack run my-distribution.yaml +``` + +### Using Legacy Directory + +```bash +# Continue using legacy directory +export LLAMA_STACK_CONFIG_DIR="/home/user/.llama" + +# Start Llama Stack +llama stack run my-distribution.yaml +``` + +### Custom Database and File Locations + +```bash +# Override specific directories +export SQLITE_STORE_DIR="/fast/ssd/llama-stack/databases" +export FILES_STORAGE_DIR="/large/disk/llama-stack/files" + +# Start Llama Stack +llama stack run my-distribution.yaml +``` + +## Benefits of XDG Compliance + +1. **Standards Compliance**: Follows established Linux/Unix conventions +2. **Better Organization**: Separates configuration, data, cache, and state files +3. **Flexibility**: Easy to customize storage locations +4. **Backup-Friendly**: Easier to backup just data files or just configuration +5. **Multi-User Support**: Better support for shared systems +6. **Cross-Platform**: Works consistently across different environments + +## Template Updates + +All distribution templates have been updated to use XDG-compliant paths: + +- Database files use `XDG_STATE_HOME` +- Model checkpoints use `XDG_DATA_HOME` +- Configuration files use `XDG_CONFIG_HOME` +- Cache files use `XDG_CACHE_HOME` + +## Troubleshooting + +### Migration Issues + +If you encounter issues during migration: + +1. **Check Permissions**: Ensure you have write permissions to target directories +2. **Disk Space**: Verify sufficient disk space in target locations +3. **Existing Files**: Handle conflicts with existing files in target locations + +### Environment Variable Conflicts + +If you have multiple environment variables set: + +1. `LLAMA_STACK_CONFIG_DIR` takes highest precedence +2. Individual `XDG_*` variables override defaults +3. Fallback to legacy `~/.llama` if it exists +4. 
Default to XDG standard paths for new installations + +### Debugging Path Resolution + +To see which paths Llama Stack is using: + +```python +from llama_stack.distribution.utils.xdg_utils import ( + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, +) + +print(f"Config: {get_llama_stack_config_dir()}") +print(f"Data: {get_llama_stack_data_dir()}") +print(f"State: {get_llama_stack_state_dir()}") +``` + +## Future Considerations + +- Container deployments will continue to use `/app` or similar paths +- Cloud deployments may use provider-specific storage systems +- The XDG specification primarily applies to local development and single-user systems \ No newline at end of file diff --git a/docs/source/providers/agents/inline_meta-reference.md b/docs/source/providers/agents/inline_meta-reference.md index 5f64f79e1..c67187d57 100644 --- a/docs/source/providers/agents/inline_meta-reference.md +++ b/docs/source/providers/agents/inline_meta-reference.md @@ -16,10 +16,10 @@ Meta's reference implementation of an agent system that can use tools, access ve ```yaml persistence_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/agents_store.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/agents_store.db responses_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/responses_store.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/responses_store.db ``` diff --git a/docs/source/providers/datasetio/inline_localfs.md b/docs/source/providers/datasetio/inline_localfs.md index 87a0c795c..db5d2c38d 100644 --- a/docs/source/providers/datasetio/inline_localfs.md +++ b/docs/source/providers/datasetio/inline_localfs.md @@ -15,7 +15,7 @@ Local filesystem-based dataset I/O provider for reading and writing datasets to ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/localfs_datasetio.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/localfs_datasetio.db ``` diff --git a/docs/source/providers/datasetio/remote_huggingface.md b/docs/source/providers/datasetio/remote_huggingface.md index 3711f7396..b268ef75d 100644 --- a/docs/source/providers/datasetio/remote_huggingface.md +++ b/docs/source/providers/datasetio/remote_huggingface.md @@ -15,7 +15,7 @@ HuggingFace datasets provider for accessing and managing datasets from the Huggi ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/huggingface_datasetio.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/huggingface_datasetio.db ``` diff --git a/docs/source/providers/eval/inline_meta-reference.md b/docs/source/providers/eval/inline_meta-reference.md index 606883c72..d0b6a835d 100644 --- a/docs/source/providers/eval/inline_meta-reference.md +++ b/docs/source/providers/eval/inline_meta-reference.md @@ -15,7 +15,7 @@ Meta's reference implementation of evaluation tasks with support for multiple la ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/meta_reference_eval.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/meta_reference_eval.db ``` diff --git a/docs/source/providers/files/inline_localfs.md b/docs/source/providers/files/inline_localfs.md index 54c489c7d..ea6795ffb 100644 --- a/docs/source/providers/files/inline_localfs.md +++ b/docs/source/providers/files/inline_localfs.md @@ -15,10 +15,10 
@@ Local filesystem-based file storage provider for managing files and documents lo ## Sample Configuration ```yaml -storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/dummy/files} +storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy/files} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/files_metadata.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/files_metadata.db ``` diff --git a/docs/source/providers/telemetry/inline_meta-reference.md b/docs/source/providers/telemetry/inline_meta-reference.md index 3e5f4b842..2090c2726 100644 --- a/docs/source/providers/telemetry/inline_meta-reference.md +++ b/docs/source/providers/telemetry/inline_meta-reference.md @@ -11,14 +11,14 @@ Meta's reference implementation of telemetry and observability using OpenTelemet | `otel_exporter_otlp_endpoint` | `str \| None` | No | | The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable. | | `service_name` | `` | No | ​ | The service name to use for telemetry | | `sinks` | `list[inline.telemetry.meta_reference.config.TelemetrySink` | No | [, ] | List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console) | -| `sqlite_db_path` | `` | No | ~/.llama/runtime/trace_store.db | The path to the SQLite database to use for storing traces | +| `sqlite_db_path` | `` | No | ${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/runtime/trace_store.db | The path to the SQLite database to use for storing traces | ## Sample Configuration ```yaml service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" sinks: ${env.TELEMETRY_SINKS:=console,sqlite} -sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/trace_store.db +sqlite_db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} ``` diff --git a/docs/source/providers/vector_io/inline_faiss.md b/docs/source/providers/vector_io/inline_faiss.md index bcff66f3f..f07a63928 100644 --- a/docs/source/providers/vector_io/inline_faiss.md +++ b/docs/source/providers/vector_io/inline_faiss.md @@ -44,7 +44,7 @@ more details about Faiss in general. ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db ``` diff --git a/docs/source/providers/vector_io/inline_meta-reference.md b/docs/source/providers/vector_io/inline_meta-reference.md index 0aac445bd..f089045ef 100644 --- a/docs/source/providers/vector_io/inline_meta-reference.md +++ b/docs/source/providers/vector_io/inline_meta-reference.md @@ -15,7 +15,7 @@ Meta's reference implementation of a vector database. ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db ``` diff --git a/docs/source/providers/vector_io/inline_milvus.md b/docs/source/providers/vector_io/inline_milvus.md index 3b3aad3fc..145f50c60 100644 --- a/docs/source/providers/vector_io/inline_milvus.md +++ b/docs/source/providers/vector_io/inline_milvus.md @@ -17,10 +17,10 @@ Please refer to the remote provider documentation. 
## Sample Configuration ```yaml -db_path: ${env.MILVUS_DB_PATH:=~/.llama/dummy}/milvus.db +db_path: ${env.MILVUS_DB_PATH:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_registry.db ``` diff --git a/docs/source/providers/vector_io/inline_qdrant.md b/docs/source/providers/vector_io/inline_qdrant.md index 63e2d81d8..cb9785819 100644 --- a/docs/source/providers/vector_io/inline_qdrant.md +++ b/docs/source/providers/vector_io/inline_qdrant.md @@ -55,7 +55,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta ## Sample Configuration ```yaml -path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db +path: ${env.QDRANT_PATH:=~/.llama/${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/qdrant.db ``` diff --git a/docs/source/providers/vector_io/inline_sqlite-vec.md b/docs/source/providers/vector_io/inline_sqlite-vec.md index ae7c45b21..8d09b090b 100644 --- a/docs/source/providers/vector_io/inline_sqlite-vec.md +++ b/docs/source/providers/vector_io/inline_sqlite-vec.md @@ -211,10 +211,10 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f ## Sample Configuration ```yaml -db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db ``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index 7e14bb8bd..8d5ab14c9 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -16,10 +16,10 @@ Please refer to the sqlite-vec provider documentation. 
## Sample Configuration ```yaml -db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db +db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 6734d8315..848dc9391 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -126,7 +126,7 @@ uri: ${env.MILVUS_ENDPOINT} token: ${env.MILVUS_TOKEN} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_remote_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 74f588a13..88e4903a7 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -52,7 +52,7 @@ user: ${env.PGVECTOR_USER} password: ${env.PGVECTOR_PASSWORD} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/pgvector_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md index d930515d5..de56680cb 100644 --- a/docs/source/providers/vector_io/remote_weaviate.md +++ b/docs/source/providers/vector_io/remote_weaviate.md @@ -38,7 +38,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more ```yaml kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/weaviate_registry.db ``` diff --git a/llama_stack/cli/llama.py b/llama_stack/cli/llama.py index 433b311e7..060559fd6 100644 --- a/llama_stack/cli/llama.py +++ b/llama_stack/cli/llama.py @@ -7,6 +7,7 @@ import argparse from .download import Download +from .migrate_xdg import MigrateXDG from .model import ModelParser from .stack import StackParser from .stack.utils import print_subcommand_description @@ -34,6 +35,7 @@ class LlamaCLIParser: StackParser.create(subparsers) Download.create(subparsers) VerifyDownload.create(subparsers) + MigrateXDG.create(subparsers) print_subcommand_description(self.parser, subparsers) diff --git a/llama_stack/cli/migrate_xdg.py b/llama_stack/cli/migrate_xdg.py new file mode 100644 index 000000000..99c52327e --- /dev/null +++ b/llama_stack/cli/migrate_xdg.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import argparse +import shutil +import sys +from pathlib import Path + +from llama_stack.distribution.utils.xdg_utils import ( + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, +) + +from .subcommand import Subcommand + + +class MigrateXDG(Subcommand): + """CLI command for migrating from legacy ~/.llama to XDG-compliant directories.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "migrate-xdg", + prog="llama migrate-xdg", + description="Migrate from legacy ~/.llama to XDG-compliant directories", + formatter_class=argparse.RawTextHelpFormatter, + ) + + self.parser.add_argument( + "--dry-run", action="store_true", help="Show what would be done without actually moving files" + ) + + self.parser.set_defaults(func=self._run_migrate_xdg_cmd) + + @staticmethod + def create(subparsers: argparse._SubParsersAction): + return MigrateXDG(subparsers) + + def _run_migrate_xdg_cmd(self, args: argparse.Namespace) -> None: + """Run the migrate-xdg command.""" + if not migrate_to_xdg(dry_run=args.dry_run): + sys.exit(1) + + +def migrate_to_xdg(dry_run: bool = False) -> bool: + """ + Migrate from legacy ~/.llama to XDG-compliant directories. + + Args: + dry_run: If True, only show what would be done without actually moving files + + Returns: + bool: True if migration was successful or not needed, False otherwise + """ + legacy_path = Path.home() / ".llama" + + if not legacy_path.exists(): + print("No legacy ~/.llama directory found. Nothing to migrate.") + return True + + # Check if we're already using XDG paths + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + state_dir = get_llama_stack_state_dir() + + if str(config_dir) == str(legacy_path): + print("Already using legacy directory. No migration needed.") + return True + + print(f"Found legacy directory at: {legacy_path}") + print("Will migrate to XDG-compliant directories:") + print(f" Config: {config_dir}") + print(f" Data: {data_dir}") + print(f" State: {state_dir}") + print() + + # Define migration mapping + migrations = [ + # (source_subdir, target_base_dir, description) + ("distributions", config_dir, "Distribution configurations"), + ("providers.d", config_dir, "External provider configurations"), + ("checkpoints", data_dir, "Model checkpoints"), + ("runtime", state_dir, "Runtime state files"), + ] + + # Check what needs to be migrated + items_to_migrate = [] + for subdir, target_base, description in migrations: + source_path = legacy_path / subdir + if source_path.exists(): + target_path = target_base / subdir + items_to_migrate.append((source_path, target_path, description)) + + if not items_to_migrate: + print("No items found to migrate.") + return True + + print("Items to migrate:") + for source_path, target_path, description in items_to_migrate: + print(f" {description}: {source_path} -> {target_path}") + + if dry_run: + print("\nDry run mode: No files will be moved.") + return True + + # Ask for confirmation + response = input("\nDo you want to proceed with the migration? 
(y/N): ") + if response.lower() not in ["y", "yes"]: + print("Migration cancelled.") + return False + + # Perform the migration + print("\nMigrating files...") + + for source_path, target_path, description in items_to_migrate: + try: + # Create target directory if it doesn't exist + target_path.parent.mkdir(parents=True, exist_ok=True) + + # Check if target already exists + if target_path.exists(): + print(f" Warning: Target already exists: {target_path}") + print(f" Skipping {description}") + continue + + # Move the directory + shutil.move(str(source_path), str(target_path)) + print(f" Moved {description}: {source_path} -> {target_path}") + + except Exception as e: + print(f" Error migrating {description}: {e}") + return False + + # Check if legacy directory is now empty (except for hidden files) + remaining_items = [item for item in legacy_path.iterdir() if not item.name.startswith(".")] + if not remaining_items: + print(f"\nMigration complete! Legacy directory {legacy_path} is now empty.") + response = input("Remove empty legacy directory? (y/N): ") + if response.lower() in ["y", "yes"]: + try: + shutil.rmtree(legacy_path) + print(f"Removed empty legacy directory: {legacy_path}") + except Exception as e: + print(f"Could not remove legacy directory: {e}") + else: + print(f"\nMigration complete! Some items remain in legacy directory: {remaining_items}") + + print("\nMigration successful!") + print("You may need to update any custom scripts or configurations that reference the old paths.") + return True + + +def main(): + parser = argparse.ArgumentParser(description="Migrate from legacy ~/.llama to XDG-compliant directories") + parser.add_argument("--dry-run", action="store_true", help="Show what would be done without actually moving files") + + args = parser.parse_args() + + if not migrate_to_xdg(dry_run=args.dry_run): + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/llama_stack/distribution/utils/config_dirs.py b/llama_stack/distribution/utils/config_dirs.py index c3e520f28..82f412aa4 100644 --- a/llama_stack/distribution/utils/config_dirs.py +++ b/llama_stack/distribution/utils/config_dirs.py @@ -7,12 +7,35 @@ import os from pathlib import Path -LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))) +from .xdg_utils import ( + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, +) +# Base directory for all llama-stack configuration +# This now uses XDG-compliant paths with backwards compatibility +LLAMA_STACK_CONFIG_DIR = get_llama_stack_config_dir() + +# Distribution configurations - stored in config directory DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions" -DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" +# Model checkpoints - stored in data directory (persistent data) +DEFAULT_CHECKPOINT_DIR = get_llama_stack_data_dir() / "checkpoints" -RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime" +# Runtime data - stored in state directory +RUNTIME_BASE_DIR = get_llama_stack_state_dir() / "runtime" +# External providers - stored in config directory EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d" + +# Legacy compatibility: if the legacy environment variable is set, use it for all paths +# This ensures that existing installations continue to work +legacy_config_dir = os.getenv("LLAMA_STACK_CONFIG_DIR") +if legacy_config_dir: + legacy_base = Path(legacy_config_dir) + LLAMA_STACK_CONFIG_DIR = legacy_base + DISTRIBS_BASE_DIR = legacy_base / "distributions" 
+ DEFAULT_CHECKPOINT_DIR = legacy_base / "checkpoints" + RUNTIME_BASE_DIR = legacy_base / "runtime" + EXTERNAL_PROVIDERS_DIR = legacy_base / "providers.d" diff --git a/llama_stack/distribution/utils/xdg_utils.py b/llama_stack/distribution/utils/xdg_utils.py new file mode 100644 index 000000000..873cafb01 --- /dev/null +++ b/llama_stack/distribution/utils/xdg_utils.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from pathlib import Path + + +def get_xdg_config_home() -> Path: + """ + Get the XDG config home directory. + + Returns: + Path: XDG_CONFIG_HOME if set, otherwise ~/.config + """ + return Path(os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))) + + +def get_xdg_data_home() -> Path: + """ + Get the XDG data home directory. + + Returns: + Path: XDG_DATA_HOME if set, otherwise ~/.local/share + """ + return Path(os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))) + + +def get_xdg_cache_home() -> Path: + """ + Get the XDG cache home directory. + + Returns: + Path: XDG_CACHE_HOME if set, otherwise ~/.cache + """ + return Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))) + + +def get_xdg_state_home() -> Path: + """ + Get the XDG state home directory. + + Returns: + Path: XDG_STATE_HOME if set, otherwise ~/.local/state + """ + return Path(os.environ.get("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))) + + +def get_llama_stack_config_dir() -> Path: + """ + Get the llama-stack configuration directory. + + This function provides backwards compatibility by checking for the legacy + LLAMA_STACK_CONFIG_DIR environment variable first, then falling back to + XDG-compliant paths. + + Returns: + Path: Configuration directory for llama-stack + """ + # Check for legacy environment variable first for backwards compatibility + legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR") + if legacy_dir: + return Path(legacy_dir) + + # Check if legacy ~/.llama directory exists and contains data + legacy_path = Path.home() / ".llama" + if legacy_path.exists() and any(legacy_path.iterdir()): + return legacy_path + + # Use XDG-compliant path + return get_xdg_config_home() / "llama-stack" + + +def get_llama_stack_data_dir() -> Path: + """ + Get the llama-stack data directory. + + This is used for persistent data like model checkpoints. + + Returns: + Path: Data directory for llama-stack + """ + # Check for legacy environment variable first for backwards compatibility + legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR") + if legacy_dir: + return Path(legacy_dir) + + # Check if legacy ~/.llama directory exists and contains data + legacy_path = Path.home() / ".llama" + if legacy_path.exists() and any(legacy_path.iterdir()): + return legacy_path + + # Use XDG-compliant path + return get_xdg_data_home() / "llama-stack" + + +def get_llama_stack_cache_dir() -> Path: + """ + Get the llama-stack cache directory. + + This is used for temporary/cache data. 
+ + Returns: + Path: Cache directory for llama-stack + """ + # Check for legacy environment variable first for backwards compatibility + legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR") + if legacy_dir: + return Path(legacy_dir) + + # Check if legacy ~/.llama directory exists and contains data + legacy_path = Path.home() / ".llama" + if legacy_path.exists() and any(legacy_path.iterdir()): + return legacy_path + + # Use XDG-compliant path + return get_xdg_cache_home() / "llama-stack" + + +def get_llama_stack_state_dir() -> Path: + """ + Get the llama-stack state directory. + + This is used for runtime state data. + + Returns: + Path: State directory for llama-stack + """ + # Check for legacy environment variable first for backwards compatibility + legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR") + if legacy_dir: + return Path(legacy_dir) + + # Check if legacy ~/.llama directory exists and contains data + legacy_path = Path.home() / ".llama" + if legacy_path.exists() and any(legacy_path.iterdir()): + return legacy_path + + # Use XDG-compliant path + return get_xdg_state_home() / "llama-stack" + + +def get_xdg_compliant_path(path_type: str, subdirectory: str | None = None, legacy_fallback: bool = True) -> Path: + """ + Get an XDG-compliant path for a given type. + + Args: + path_type: Type of path ('config', 'data', 'cache', 'state') + subdirectory: Optional subdirectory within the base path + legacy_fallback: Whether to check for legacy ~/.llama directory + + Returns: + Path: XDG-compliant path + + Raises: + ValueError: If path_type is not recognized + """ + path_map = { + "config": get_llama_stack_config_dir, + "data": get_llama_stack_data_dir, + "cache": get_llama_stack_cache_dir, + "state": get_llama_stack_state_dir, + } + + if path_type not in path_map: + raise ValueError(f"Unknown path type: {path_type}. Must be one of: {list(path_map.keys())}") + + base_path = path_map[path_type]() + + if subdirectory: + return base_path / subdirectory + + return base_path + + +def migrate_legacy_directory() -> bool: + """ + Migrate from legacy ~/.llama directory to XDG-compliant directories. + + This function helps users migrate their existing data to the new + XDG-compliant structure. + + Returns: + bool: True if migration was successful or not needed, False otherwise + """ + legacy_path = Path.home() / ".llama" + + if not legacy_path.exists(): + return True # No migration needed + + print(f"Found legacy directory at {legacy_path}") + print("Consider migrating to XDG-compliant directories:") + print(f" Config: {get_llama_stack_config_dir()}") + print(f" Data: {get_llama_stack_data_dir()}") + print(f" Cache: {get_llama_stack_cache_dir()}") + print(f" State: {get_llama_stack_state_dir()}") + print("Migration can be done by moving the appropriate subdirectories.") + + return True + + +def ensure_directory_exists(path: Path) -> None: + """ + Ensure a directory exists, creating it if necessary. + + Args: + path: Path to the directory + """ + path.mkdir(parents=True, exist_ok=True) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 76d789d07..8a30abacd 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -4,6 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import asyncio import base64 diff --git a/llama_stack/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py index 965ffa079..2ac7e9010 100644 --- a/llama_stack/strong_typing/auxiliary.py +++ b/llama_stack/strong_typing/auxiliary.py @@ -12,23 +12,12 @@ Type-safe data interchange for Python data classes. import dataclasses import sys +from collections.abc import Callable from dataclasses import is_dataclass -from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload - -if sys.version_info >= (3, 9): - from typing import Annotated as Annotated -else: - from typing_extensions import Annotated as Annotated - -if sys.version_info >= (3, 10): - from typing import TypeAlias as TypeAlias -else: - from typing_extensions import TypeAlias as TypeAlias - -if sys.version_info >= (3, 11): - from typing import dataclass_transform as dataclass_transform -else: - from typing_extensions import dataclass_transform as dataclass_transform +from typing import Annotated as Annotated +from typing import TypeAlias as TypeAlias +from typing import TypeVar, overload +from typing import dataclass_transform as dataclass_transform T = TypeVar("T") @@ -56,17 +45,17 @@ class CompactDataClass: @overload -def typeannotation(cls: Type[T], /) -> Type[T]: ... +def typeannotation(cls: type[T], /) -> type[T]: ... @overload -def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ... +def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[type[T]], type[T]]: ... @dataclass_transform(eq_default=True, order_default=False) def typeannotation( - cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False -) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + cls: type[T] | None = None, *, eq: bool = True, order: bool = False +) -> type[T] | Callable[[type[T]], type[T]]: """ Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. @@ -76,7 +65,7 @@ def typeannotation( :returns: A data-class type, or a wrapper for data-class types. """ - def wrap(cls: Type[T]) -> Type[T]: + def wrap(cls: type[T]) -> type[T]: # mypy fails to equate bound-y functions (first argument interpreted as # the bound object) with class methods, hence the `ignore` directive. cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign] @@ -179,41 +168,41 @@ class SpecialConversion: "Indicates that the annotated type is subject to custom conversion rules." 
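+# NOTE: the `type` statements below use PEP 695 alias syntax and therefore
+# require Python 3.12 or newer.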
-int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)] -int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)] -int32: TypeAlias = Annotated[ +type int8 = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)] +type int16 = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)] +type int32 = Annotated[ int, Signed(True), Storage(4), IntegerRange(-2147483648, 2147483647), ] -int64: TypeAlias = Annotated[ +type int64 = Annotated[ int, Signed(True), Storage(8), IntegerRange(-9223372036854775808, 9223372036854775807), ] -uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)] -uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)] -uint32: TypeAlias = Annotated[ +type uint8 = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)] +type uint16 = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)] +type uint32 = Annotated[ int, Signed(False), Storage(4), IntegerRange(0, 4294967295), ] -uint64: TypeAlias = Annotated[ +type uint64 = Annotated[ int, Signed(False), Storage(8), IntegerRange(0, 18446744073709551615), ] -float32: TypeAlias = Annotated[float, Storage(4)] -float64: TypeAlias = Annotated[float, Storage(8)] +type float32 = Annotated[float, Storage(4)] +type float64 = Annotated[float, Storage(8)] # maps globals of type Annotated[T, ...] defined in this module to their string names -_auxiliary_types: Dict[object, str] = {} +_auxiliary_types: dict[object, str] = {} module = sys.modules[__name__] for var in dir(module): typ = getattr(module, var) @@ -222,7 +211,7 @@ for var in dir(module): _auxiliary_types[typ] = var -def get_auxiliary_format(data_type: object) -> Optional[str]: +def get_auxiliary_format(data_type: object) -> str | None: "Returns the JSON format string corresponding to an auxiliary type." 
return _auxiliary_types.get(data_type) diff --git a/llama_stack/strong_typing/classdef.py b/llama_stack/strong_typing/classdef.py index 5ead886d4..cbb51529c 100644 --- a/llama_stack/strong_typing/classdef.py +++ b/llama_stack/strong_typing/classdef.py @@ -12,12 +12,11 @@ import enum import ipaddress import math import re -import sys import types import typing import uuid from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Literal, TypeVar, Union from .auxiliary import ( Alias, @@ -40,57 +39,57 @@ T = TypeVar("T") @dataclass class JsonSchemaNode: - title: Optional[str] - description: Optional[str] + title: str | None + description: str | None @dataclass class JsonSchemaType(JsonSchemaNode): type: str - format: Optional[str] + format: str | None @dataclass class JsonSchemaBoolean(JsonSchemaType): type: Literal["boolean"] - const: Optional[bool] - default: Optional[bool] - examples: Optional[List[bool]] + const: bool | None + default: bool | None + examples: list[bool] | None @dataclass class JsonSchemaInteger(JsonSchemaType): type: Literal["integer"] - const: Optional[int] - default: Optional[int] - examples: Optional[List[int]] - enum: Optional[List[int]] - minimum: Optional[int] - maximum: Optional[int] + const: int | None + default: int | None + examples: list[int] | None + enum: list[int] | None + minimum: int | None + maximum: int | None @dataclass class JsonSchemaNumber(JsonSchemaType): type: Literal["number"] - const: Optional[float] - default: Optional[float] - examples: Optional[List[float]] - minimum: Optional[float] - maximum: Optional[float] - exclusiveMinimum: Optional[float] - exclusiveMaximum: Optional[float] - multipleOf: Optional[float] + const: float | None + default: float | None + examples: list[float] | None + minimum: float | None + maximum: float | None + exclusiveMinimum: float | None + exclusiveMaximum: float | None + multipleOf: float | None @dataclass class JsonSchemaString(JsonSchemaType): type: Literal["string"] - const: Optional[str] - default: Optional[str] - examples: Optional[List[str]] - enum: Optional[List[str]] - minLength: Optional[int] - maxLength: Optional[int] + const: str | None + default: str | None + examples: list[str] | None + enum: list[str] | None + minLength: int | None + maxLength: int | None @dataclass @@ -102,9 +101,9 @@ class JsonSchemaArray(JsonSchemaType): @dataclass class JsonSchemaObject(JsonSchemaType): type: Literal["object"] - properties: Optional[Dict[str, "JsonSchemaAny"]] - additionalProperties: Optional[bool] - required: Optional[List[str]] + properties: dict[str, "JsonSchemaAny"] | None + additionalProperties: bool | None + required: list[str] | None @dataclass @@ -114,24 +113,24 @@ class JsonSchemaRef(JsonSchemaNode): @dataclass class JsonSchemaAllOf(JsonSchemaNode): - allOf: List["JsonSchemaAny"] + allOf: list["JsonSchemaAny"] @dataclass class JsonSchemaAnyOf(JsonSchemaNode): - anyOf: List["JsonSchemaAny"] + anyOf: list["JsonSchemaAny"] @dataclass class Discriminator: propertyName: str - mapping: Dict[str, str] + mapping: dict[str, str] @dataclass class JsonSchemaOneOf(JsonSchemaNode): - oneOf: List["JsonSchemaAny"] - discriminator: Optional[Discriminator] + oneOf: list["JsonSchemaAny"] + discriminator: Discriminator | None JsonSchemaAny = Union[ @@ -149,10 +148,10 @@ JsonSchemaAny = Union[ @dataclass class JsonSchemaTopLevelObject(JsonSchemaObject): schema: Annotated[str, Alias("$schema")] - definitions: Optional[Dict[str, 
JsonSchemaAny]] + definitions: dict[str, JsonSchemaAny] | None -def integer_range_to_type(min_value: float, max_value: float) -> type: +def integer_range_to_type(min_value: float, max_value: float) -> Any: if min_value >= -(2**15) and max_value < 2**15: return int16 elif min_value >= -(2**31) and max_value < 2**31: @@ -173,11 +172,11 @@ def enum_safe_name(name: str) -> str: def enum_values_to_type( module: types.ModuleType, name: str, - values: Dict[str, Any], - title: Optional[str] = None, - description: Optional[str] = None, -) -> Type[enum.Enum]: - enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore + values: dict[str, Any], + title: str | None = None, + description: str | None = None, +) -> type[enum.Enum]: + enum_class: type[enum.Enum] = enum.Enum(name, values) # type: ignore # assign the newly created type to the same module where the defining class is enum_class.__module__ = module.__name__ @@ -330,7 +329,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode type_def = node_to_typedef(module, context, node.items) if type_def.default is not dataclasses.MISSING: raise TypeError("disallowed: `default` for array element type") - list_type = List[(type_def.type,)] # type: ignore + list_type = list[(type_def.type,)] # type: ignore return TypeDef(list_type, dataclasses.MISSING) elif isinstance(node, JsonSchemaObject): @@ -344,8 +343,8 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode class_name = context - fields: List[Tuple[str, Any, dataclasses.Field]] = [] - params: Dict[str, DocstringParam] = {} + fields: list[tuple[str, Any, dataclasses.Field]] = [] + params: dict[str, DocstringParam] = {} for prop_name, prop_node in node.properties.items(): type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node) if prop_name in required: @@ -358,10 +357,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode params[prop_name] = DocstringParam(prop_name, prop_desc) fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING) - if sys.version_info >= (3, 12): - class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__) - else: - class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__}) + class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__) class_type.__doc__ = str( Docstring( short_description=node.title, @@ -388,7 +384,7 @@ class SchemaFlatteningOptions: recursive: bool = False -def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema: +def flatten_schema(schema: Schema, *, options: SchemaFlatteningOptions | None = None) -> Schema: top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema)) flattener = SchemaFlattener(options) obj = flattener.flatten(top_node) @@ -398,7 +394,7 @@ def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] class SchemaFlattener: options: SchemaFlatteningOptions - def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None: + def __init__(self, options: SchemaFlatteningOptions | None = None) -> None: self.options = options or SchemaFlatteningOptions() def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject: @@ -406,10 +402,10 @@ class SchemaFlattener: return source_node source_props = source_node.properties or {} - target_props: Dict[str, JsonSchemaAny] = {} + target_props: dict[str, JsonSchemaAny] = {} 
source_reqs = source_node.required or [] - target_reqs: List[str] = [] + target_reqs: list[str] = [] for name, prop in source_props.items(): if not isinstance(prop, JsonSchemaObject): diff --git a/llama_stack/strong_typing/core.py b/llama_stack/strong_typing/core.py index 501b6a5db..5f3764aeb 100644 --- a/llama_stack/strong_typing/core.py +++ b/llama_stack/strong_typing/core.py @@ -10,7 +10,7 @@ Type-safe data interchange for Python data classes. :see: https://github.com/hunyadi/strong_typing """ -from typing import Dict, List, Union +from typing import Union class JsonObject: @@ -28,8 +28,8 @@ JsonType = Union[ int, float, str, - Dict[str, "JsonType"], - List["JsonType"], + dict[str, "JsonType"], + list["JsonType"], ] # a JSON type that cannot contain `null` values @@ -38,9 +38,9 @@ StrictJsonType = Union[ int, float, str, - Dict[str, "StrictJsonType"], - List["StrictJsonType"], + dict[str, "StrictJsonType"], + list["StrictJsonType"], ] # a meta-type that captures the object type in a JSON schema -Schema = Dict[str, JsonType] +Schema = dict[str, JsonType] diff --git a/llama_stack/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py index 883590862..66f7b6e24 100644 --- a/llama_stack/strong_typing/deserializer.py +++ b/llama_stack/strong_typing/deserializer.py @@ -20,19 +20,14 @@ import ipaddress import sys import typing import uuid +from collections.abc import Callable from types import ModuleType from typing import ( Any, - Callable, - Dict, Generic, - List, Literal, NamedTuple, Optional, - Set, - Tuple, - Type, TypeVar, Union, ) @@ -70,7 +65,7 @@ V = TypeVar("V") class Deserializer(abc.ABC, Generic[T]): "Parses a JSON value into a Python type." - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: """ Creates auxiliary parsers that this parser is depending on. @@ -203,19 +198,19 @@ class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]): return ipaddress.IPv6Address(data) -class ListDeserializer(Deserializer[List[T]]): +class ListDeserializer(Deserializer[list[T]]): "Recursively de-serializes a JSON array into a Python `list`." - item_type: Type[T] + item_type: type[T] item_parser: Deserializer - def __init__(self, item_type: Type[T]) -> None: + def __init__(self, item_type: type[T]) -> None: self.item_type = item_type - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.item_parser = _get_deserializer(self.item_type, context) - def parse(self, data: JsonType) -> List[T]: + def parse(self, data: JsonType) -> list[T]: if not isinstance(data, list): type_name = python_type_to_str(self.item_type) raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}") @@ -223,19 +218,19 @@ class ListDeserializer(Deserializer[List[T]]): return [self.item_parser.parse(item) for item in data] -class DictDeserializer(Deserializer[Dict[K, V]]): +class DictDeserializer(Deserializer[dict[K, V]]): "Recursively de-serializes a JSON object into a Python `dict`." 
- key_type: Type[K] - value_type: Type[V] + key_type: type[K] + value_type: type[V] value_parser: Deserializer[V] - def __init__(self, key_type: Type[K], value_type: Type[V]) -> None: + def __init__(self, key_type: type[K], value_type: type[V]) -> None: self.key_type = key_type self.value_type = value_type self._check_key_type() - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.value_parser = _get_deserializer(self.value_type, context) def _check_key_type(self) -> None: @@ -264,48 +259,48 @@ class DictDeserializer(Deserializer[Dict[K, V]]): value_type_name = python_type_to_str(self.value_type) return f"Dict[{key_type_name}, {value_type_name}]" - def parse(self, data: JsonType) -> Dict[K, V]: + def parse(self, data: JsonType) -> dict[K, V]: if not isinstance(data, dict): raise JsonTypeError( f"`type `{self.container_type}` expects JSON `object` data but instead received: {data}" ) - return dict( - (self.key_type(key), self.value_parser.parse(value)) # type: ignore[call-arg] + return { + self.key_type(key): self.value_parser.parse(value) # type: ignore[call-arg] for key, value in data.items() - ) + } -class SetDeserializer(Deserializer[Set[T]]): +class SetDeserializer(Deserializer[set[T]]): "Recursively de-serializes a JSON list into a Python `set`." - member_type: Type[T] + member_type: type[T] member_parser: Deserializer - def __init__(self, member_type: Type[T]) -> None: + def __init__(self, member_type: type[T]) -> None: self.member_type = member_type - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.member_parser = _get_deserializer(self.member_type, context) - def parse(self, data: JsonType) -> Set[T]: + def parse(self, data: JsonType) -> set[T]: if not isinstance(data, list): type_name = python_type_to_str(self.member_type) raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}") - return set(self.member_parser.parse(item) for item in data) + return {self.member_parser.parse(item) for item in data} -class TupleDeserializer(Deserializer[Tuple[Any, ...]]): +class TupleDeserializer(Deserializer[tuple[Any, ...]]): "Recursively de-serializes a JSON list into a Python `tuple`." - item_types: Tuple[Type[Any], ...] - item_parsers: Tuple[Deserializer[Any], ...] + item_types: tuple[type[Any], ...] + item_parsers: tuple[Deserializer[Any], ...] - def __init__(self, item_types: Tuple[Type[Any], ...]) -> None: + def __init__(self, item_types: tuple[type[Any], ...]) -> None: self.item_types = item_types - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types) @property @@ -313,7 +308,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]): type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types) return f"Tuple[{type_names}]" - def parse(self, data: JsonType) -> Tuple[Any, ...]: + def parse(self, data: JsonType) -> tuple[Any, ...]: if not isinstance(data, list) or len(data) != len(self.item_parsers): if not isinstance(data, list): raise JsonTypeError( @@ -331,13 +326,13 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]): class UnionDeserializer(Deserializer): "De-serializes a JSON value (of any type) into a Python union type." - member_types: Tuple[type, ...] - member_parsers: Tuple[Deserializer, ...] 
+ member_types: tuple[type, ...] + member_parsers: tuple[Deserializer, ...] - def __init__(self, member_types: Tuple[type, ...]) -> None: + def __init__(self, member_types: tuple[type, ...]) -> None: self.member_types = member_types - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types) def parse(self, data: JsonType) -> Any: @@ -354,15 +349,15 @@ class UnionDeserializer(Deserializer): raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}") -def get_literal_properties(typ: type) -> Set[str]: +def get_literal_properties(typ: type) -> set[str]: "Returns the names of all properties in a class that are of a literal type." - return set( + return { property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type) - ) + } -def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]: +def get_discriminating_properties(types: tuple[type, ...]) -> set[str]: "Returns a set of properties with literal type that are common across all specified classes." if not types or not all(isinstance(typ, type) for typ in types): @@ -378,15 +373,15 @@ def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]: class TaggedUnionDeserializer(Deserializer): "De-serializes a JSON value with one or more disambiguating properties into a Python union type." - member_types: Tuple[type, ...] - disambiguating_properties: Set[str] - member_parsers: Dict[Tuple[str, Any], Deserializer] + member_types: tuple[type, ...] + disambiguating_properties: set[str] + member_parsers: dict[tuple[str, Any], Deserializer] - def __init__(self, member_types: Tuple[type, ...]) -> None: + def __init__(self, member_types: tuple[type, ...]) -> None: self.member_types = member_types self.disambiguating_properties = get_discriminating_properties(member_types) - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: self.member_parsers = {} for member_type in self.member_types: for property_name in self.disambiguating_properties: @@ -435,13 +430,13 @@ class TaggedUnionDeserializer(Deserializer): class LiteralDeserializer(Deserializer): "De-serializes a JSON value into a Python literal type." - values: Tuple[Any, ...] + values: tuple[Any, ...] parser: Deserializer - def __init__(self, values: Tuple[Any, ...]) -> None: + def __init__(self, values: tuple[Any, ...]) -> None: self.values = values - def build(self, context: Optional[ModuleType]) -> None: + def build(self, context: ModuleType | None) -> None: literal_type_tuple = tuple(type(value) for value in self.values) literal_type_set = set(literal_type_tuple) if len(literal_type_set) != 1: @@ -464,9 +459,9 @@ class LiteralDeserializer(Deserializer): class EnumDeserializer(Deserializer[E]): "Returns an enumeration instance based on the enumeration value read from a JSON value." - enum_type: Type[E] + enum_type: type[E] - def __init__(self, enum_type: Type[E]) -> None: + def __init__(self, enum_type: type[E]) -> None: self.enum_type = enum_type def parse(self, data: JsonType) -> E: @@ -504,13 +499,13 @@ class FieldDeserializer(abc.ABC, Generic[T, R]): self.parser = parser @abc.abstractmethod - def parse_field(self, data: Dict[str, JsonType]) -> R: ... + def parse_field(self, data: dict[str, JsonType]) -> R: ... 
class RequiredFieldDeserializer(FieldDeserializer[T, T]): "Deserializes a JSON property into a mandatory Python object field." - def parse_field(self, data: Dict[str, JsonType]) -> T: + def parse_field(self, data: dict[str, JsonType]) -> T: if self.property_name not in data: raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}") @@ -520,7 +515,7 @@ class RequiredFieldDeserializer(FieldDeserializer[T, T]): class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]): "Deserializes a JSON property into an optional Python object field with a default value of `None`." - def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]: + def parse_field(self, data: dict[str, JsonType]) -> T | None: value = data.get(self.property_name) if value is not None: return self.parser.parse(value) @@ -543,7 +538,7 @@ class DefaultFieldDeserializer(FieldDeserializer[T, T]): super().__init__(property_name, field_name, parser) self.default_value = default_value - def parse_field(self, data: Dict[str, JsonType]) -> T: + def parse_field(self, data: dict[str, JsonType]) -> T: value = data.get(self.property_name) if value is not None: return self.parser.parse(value) @@ -566,7 +561,7 @@ class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]): super().__init__(property_name, field_name, parser) self.default_factory = default_factory - def parse_field(self, data: Dict[str, JsonType]) -> T: + def parse_field(self, data: dict[str, JsonType]) -> T: value = data.get(self.property_name) if value is not None: return self.parser.parse(value) @@ -578,22 +573,22 @@ class ClassDeserializer(Deserializer[T]): "Base class for de-serializing class-like types such as data classes, named tuples and regular classes." class_type: type - property_parsers: List[FieldDeserializer] - property_fields: Set[str] + property_parsers: list[FieldDeserializer] + property_fields: set[str] - def __init__(self, class_type: Type[T]) -> None: + def __init__(self, class_type: type[T]) -> None: self.class_type = class_type - def assign(self, property_parsers: List[FieldDeserializer]) -> None: + def assign(self, property_parsers: list[FieldDeserializer]) -> None: self.property_parsers = property_parsers - self.property_fields = set(property_parser.property_name for property_parser in property_parsers) + self.property_fields = {property_parser.property_name for property_parser in property_parsers} def parse(self, data: JsonType) -> T: if not isinstance(data, dict): type_name = python_type_to_str(self.class_type) raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}") - object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data) + object_data: dict[str, JsonType] = typing.cast(dict[str, JsonType], data) field_values = {} for property_parser in self.property_parsers: @@ -619,8 +614,8 @@ class ClassDeserializer(Deserializer[T]): class NamedTupleDeserializer(ClassDeserializer[NamedTuple]): "De-serializes a named tuple from a JSON `object`." 
- def build(self, context: Optional[ModuleType]) -> None: - property_parsers: List[FieldDeserializer] = [ + def build(self, context: ModuleType | None) -> None: + property_parsers: list[FieldDeserializer] = [ RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context)) for field_name, field_type in get_resolved_hints(self.class_type).items() ] @@ -634,13 +629,13 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]): class DataclassDeserializer(ClassDeserializer[T]): "De-serializes a data class from a JSON `object`." - def __init__(self, class_type: Type[T]) -> None: + def __init__(self, class_type: type[T]) -> None: if not dataclasses.is_dataclass(class_type): raise TypeError("expected: data-class type") super().__init__(class_type) # type: ignore[arg-type] - def build(self, context: Optional[ModuleType]) -> None: - property_parsers: List[FieldDeserializer] = [] + def build(self, context: ModuleType | None) -> None: + property_parsers: list[FieldDeserializer] = [] resolved_hints = get_resolved_hints(self.class_type) for field in dataclasses.fields(self.class_type): field_type = resolved_hints[field.name] @@ -651,7 +646,7 @@ class DataclassDeserializer(ClassDeserializer[T]): has_default_factory = field.default_factory is not dataclasses.MISSING if is_optional: - required_type: Type[T] = unwrap_optional_type(field_type) + required_type: type[T] = unwrap_optional_type(field_type) else: required_type = field_type @@ -691,15 +686,15 @@ class FrozenDataclassDeserializer(DataclassDeserializer[T]): class TypedClassDeserializer(ClassDeserializer[T]): "De-serializes a class with type annotations from a JSON `object` by iterating over class properties." - def build(self, context: Optional[ModuleType]) -> None: - property_parsers: List[FieldDeserializer] = [] + def build(self, context: ModuleType | None) -> None: + property_parsers: list[FieldDeserializer] = [] for field_name, field_type in get_resolved_hints(self.class_type).items(): property_name = python_field_to_json_property(field_name, field_type) is_optional = is_type_optional(field_type) if is_optional: - required_type: Type[T] = unwrap_optional_type(field_type) + required_type: type[T] = unwrap_optional_type(field_type) else: required_type = field_type @@ -715,7 +710,7 @@ class TypedClassDeserializer(ClassDeserializer[T]): super().assign(property_parsers) -def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer: +def create_deserializer(typ: TypeLike, context: ModuleType | None = None) -> Deserializer: """ Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string. @@ -741,15 +736,15 @@ def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> return _get_deserializer(typ, context) -_CACHE: Dict[Tuple[str, str], Deserializer] = {} +_CACHE: dict[tuple[str, str], Deserializer] = {} -def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer: +def _get_deserializer(typ: TypeLike, context: ModuleType | None) -> Deserializer: "Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string." 
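Illustrative aside (not part of the patch): the required/defaulted/optional field deserializers above pair up with the field serializers later in this patch to give a plain dataclass round trip. A sketch with a hypothetical `Job` class, assuming `object_to_json`/`json_to_object` from `llama_stack/strong_typing/serialization.py` behave as their hunks below suggest:

```python
from dataclasses import dataclass

from llama_stack.strong_typing.serialization import json_to_object, object_to_json


@dataclass
class Job:
    name: str                 # required -> RequiredFieldDeserializer
    retries: int = 3          # defaulted -> DefaultFieldDeserializer
    owner: str | None = None  # optional  -> parsed as None when absent


data = object_to_json(Job(name="index"))
# None-valued fields are skipped by FieldSerializer, so `owner` does not appear in `data`.
restored = json_to_object(Job, data)
assert restored == Job(name="index")
```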
cache_key = None - if isinstance(typ, (str, typing.ForwardRef)): + if isinstance(typ, str | typing.ForwardRef): if context is None: raise TypeError(f"missing context for evaluating type: {typ}") diff --git a/llama_stack/strong_typing/docstring.py b/llama_stack/strong_typing/docstring.py index 497c9ea82..eaea96198 100644 --- a/llama_stack/strong_typing/docstring.py +++ b/llama_stack/strong_typing/docstring.py @@ -15,17 +15,12 @@ import collections.abc import dataclasses import inspect import re -import sys import types import typing +from collections.abc import Callable from dataclasses import dataclass from io import StringIO -from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar - -if sys.version_info >= (3, 10): - from typing import TypeGuard -else: - from typing_extensions import TypeGuard +from typing import Any, Protocol, TypeGuard, TypeVar from .inspection import ( DataclassInstance, @@ -110,14 +105,14 @@ class Docstring: :param returns: The returns declaration extracted from a docstring. """ - short_description: Optional[str] = None - long_description: Optional[str] = None - params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict) - returns: Optional[DocstringReturns] = None - raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict) + short_description: str | None = None + long_description: str | None = None + params: dict[str, DocstringParam] = dataclasses.field(default_factory=dict) + returns: DocstringReturns | None = None + raises: dict[str, DocstringRaises] = dataclasses.field(default_factory=dict) @property - def full_description(self) -> Optional[str]: + def full_description(self) -> str | None: if self.short_description and self.long_description: return f"{self.short_description}\n\n{self.long_description}" elif self.short_description: @@ -158,18 +153,18 @@ class Docstring: return s -def is_exception(member: object) -> TypeGuard[Type[BaseException]]: +def is_exception(member: object) -> TypeGuard[type[BaseException]]: return isinstance(member, type) and issubclass(member, BaseException) -def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]: +def get_exceptions(module: types.ModuleType) -> dict[str, type[BaseException]]: "Returns all exception classes declared in a module." 
- return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)} + return dict(inspect.getmembers(module, is_exception)) class SupportsDoc(Protocol): - __doc__: Optional[str] + __doc__: str | None def _maybe_unwrap_async_iterator(t): @@ -213,7 +208,7 @@ def parse_type(typ: SupportsDoc) -> Docstring: # assign exception types defining_module = inspect.getmodule(typ) if defining_module: - context: Dict[str, type] = {} + context: dict[str, type] = {} context.update(get_exceptions(builtins)) context.update(get_exceptions(defining_module)) for exc_name, exc in docstring.raises.items(): @@ -262,8 +257,8 @@ def parse_text(text: str) -> Docstring: else: long_description = None - params: Dict[str, DocstringParam] = {} - raises: Dict[str, DocstringRaises] = {} + params: dict[str, DocstringParam] = {} + raises: dict[str, DocstringRaises] = {} returns = None for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE): chunk = match.group(0) @@ -325,7 +320,7 @@ def has_docstring(typ: SupportsDoc) -> bool: return bool(typ.__doc__) -def get_docstring(typ: SupportsDoc) -> Optional[str]: +def get_docstring(typ: SupportsDoc) -> str | None: if typ.__doc__ is None: return None @@ -348,7 +343,7 @@ def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False check_function_docstring(typ, docstring, strict) -def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None: +def check_dataclass_docstring(typ: type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None: """ Verifies the doc-string of a data-class type. diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py index a75a170cf..512702173 100644 --- a/llama_stack/strong_typing/inspection.py +++ b/llama_stack/strong_typing/inspection.py @@ -22,34 +22,19 @@ import sys import types import typing import uuid +from collections.abc import Callable, Iterable from typing import ( + Annotated, Any, - Callable, - Dict, - Iterable, - List, Literal, NamedTuple, - Optional, Protocol, - Set, - Tuple, - Type, + TypeGuard, TypeVar, Union, runtime_checkable, ) -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - -if sys.version_info >= (3, 10): - from typing import TypeGuard -else: - from typing_extensions import TypeGuard - S = TypeVar("S") T = TypeVar("T") K = TypeVar("K") @@ -80,28 +65,20 @@ def _is_type_like(data_type: object) -> bool: return False -if sys.version_info >= (3, 9): - TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any] +TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any] - def is_type_like( - data_type: object, - ) -> TypeGuard[TypeLike]: - """ - Checks if the object is a type or type-like object (e.g. generic type). - :param data_type: The object to validate. - :returns: True if the object is a type or type-like object. - """ +def is_type_like( + data_type: object, +) -> TypeGuard[TypeLike]: + """ + Checks if the object is a type or type-like object (e.g. generic type). - return _is_type_like(data_type) + :param data_type: The object to validate. + :returns: True if the object is a type or type-like object. 
+ """ -else: - TypeLike = object - - def is_type_like( - data_type: object, - ) -> bool: - return _is_type_like(data_type) + return _is_type_like(data_type) def evaluate_member_type(typ: Any, cls: type) -> Any: @@ -129,20 +106,17 @@ def evaluate_type(typ: Any, module: types.ModuleType) -> Any: # evaluate data-class field whose type annotation is a string return eval(typ, module.__dict__, locals()) if isinstance(typ, typing.ForwardRef): - if sys.version_info >= (3, 9): - return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset()) - else: - return typ._evaluate(module.__dict__, locals()) + return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset()) else: return typ @runtime_checkable class DataclassInstance(Protocol): - __dataclass_fields__: typing.ClassVar[Dict[str, dataclasses.Field]] + __dataclass_fields__: typing.ClassVar[dict[str, dataclasses.Field]] -def is_dataclass_type(typ: Any) -> TypeGuard[Type[DataclassInstance]]: +def is_dataclass_type(typ: Any) -> TypeGuard[type[DataclassInstance]]: "True if the argument corresponds to a data class type (but not an instance)." typ = unwrap_annotated_type(typ) @@ -167,14 +141,14 @@ class DataclassField: self.default = default -def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]: +def dataclass_fields(cls: type[DataclassInstance]) -> Iterable[DataclassField]: "Generates the fields of a data-class resolving forward references." for field in dataclasses.fields(cls): yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default) -def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField: +def dataclass_field_by_name(cls: type[DataclassInstance], name: str) -> DataclassField: "Looks up a field in a data-class by its field name." for field in dataclasses.fields(cls): @@ -190,7 +164,7 @@ def is_named_tuple_instance(obj: Any) -> TypeGuard[NamedTuple]: return is_named_tuple_type(type(obj)) -def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]: +def is_named_tuple_type(typ: Any) -> TypeGuard[type[NamedTuple]]: """ True if the argument corresponds to a named tuple type. @@ -217,26 +191,14 @@ def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]: return all(isinstance(n, str) for n in f) -if sys.version_info >= (3, 11): +def is_type_enum(typ: object) -> TypeGuard[type[enum.Enum]]: + "True if the specified type is an enumeration type." - def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]: - "True if the specified type is an enumeration type." - - typ = unwrap_annotated_type(typ) - return isinstance(typ, enum.EnumType) - -else: - - def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]: - "True if the specified type is an enumeration type." - - typ = unwrap_annotated_type(typ) - - # use an explicit isinstance(..., type) check to filter out special forms like generics - return isinstance(typ, type) and issubclass(typ, enum.Enum) + typ = unwrap_annotated_type(typ) + return isinstance(typ, enum.EnumType) -def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]: +def enum_value_types(enum_type: type[enum.Enum]) -> list[type]: """ Returns all unique value types of the `enum.Enum` type in definition order. 
""" @@ -246,8 +208,8 @@ def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]: def extend_enum( - source: Type[enum.Enum], -) -> Callable[[Type[enum.Enum]], Type[enum.Enum]]: + source: type[enum.Enum], +) -> Callable[[type[enum.Enum]], type[enum.Enum]]: """ Creates a new enumeration type extending the set of values in an existing type. @@ -255,13 +217,13 @@ def extend_enum( :returns: A new enumeration type with the extended set of values. """ - def wrap(extend: Type[enum.Enum]) -> Type[enum.Enum]: + def wrap(extend: type[enum.Enum]) -> type[enum.Enum]: # create new enumeration type combining the values from both types - values: Dict[str, Any] = {} + values: dict[str, Any] = {} values.update((e.name, e.value) for e in source) values.update((e.name, e.value) for e in extend) # mypy fails to determine that __name__ is always a string; hence the `ignore` directive. - enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc] + enum_class: type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc] # assign the newly created type to the same module where the extending class is defined enum_class.__module__ = extend.__module__ @@ -273,22 +235,13 @@ def extend_enum( return wrap -if sys.version_info >= (3, 10): +def _is_union_like(typ: object) -> bool: + "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`." - def _is_union_like(typ: object) -> bool: - "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`." - - return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType) - -else: - - def _is_union_like(typ: object) -> bool: - "True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`." - - return typing.get_origin(typ) is Union + return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType) -def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]: +def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[type[Any | None]]: """ True if the type annotation corresponds to an optional type (e.g. `Optional[T]` or `Union[T1,T2,None]`). @@ -309,7 +262,7 @@ def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Option return False -def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]: +def unwrap_optional_type(typ: type[T | None]) -> type[T]: """ Extracts the inner type of an optional type. @@ -320,7 +273,7 @@ def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]: return rewrap_annotated_type(_unwrap_optional_type, typ) -def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]: +def _unwrap_optional_type(typ: type[T | None]) -> type[T]: "Extracts the type qualified as optional (e.g. returns `T` for `Optional[T]`)." # Optional[T] is represented internally as Union[T, None] @@ -342,7 +295,7 @@ def is_type_union(typ: object) -> bool: return False -def unwrap_union_types(typ: object) -> Tuple[object, ...]: +def unwrap_union_types(typ: object) -> tuple[object, ...]: """ Extracts the inner types of a union type. @@ -354,7 +307,7 @@ def unwrap_union_types(typ: object) -> Tuple[object, ...]: return _unwrap_union_types(typ) -def _unwrap_union_types(typ: object) -> Tuple[object, ...]: +def _unwrap_union_types(typ: object) -> tuple[object, ...]: "Extracts the types in a union (e.g. returns a tuple of types `T1` and `T2` for `Union[T1, T2]`)." 
if not _is_union_like(typ): @@ -385,7 +338,7 @@ def unwrap_literal_value(typ: object) -> Any: return args[0] -def unwrap_literal_values(typ: object) -> Tuple[Any, ...]: +def unwrap_literal_values(typ: object) -> tuple[Any, ...]: """ Extracts the constant values captured by a literal type. @@ -397,7 +350,7 @@ def unwrap_literal_values(typ: object) -> Tuple[Any, ...]: return typing.get_args(typ) -def unwrap_literal_types(typ: object) -> Tuple[type, ...]: +def unwrap_literal_types(typ: object) -> tuple[type, ...]: """ Extracts the types of the constant values captured by a literal type. @@ -408,14 +361,14 @@ def unwrap_literal_types(typ: object) -> Tuple[type, ...]: return tuple(type(t) for t in unwrap_literal_values(typ)) -def is_generic_list(typ: object) -> TypeGuard[Type[list]]: +def is_generic_list(typ: object) -> TypeGuard[type[list]]: "True if the specified type is a generic list, i.e. `List[T]`." typ = unwrap_annotated_type(typ) return typing.get_origin(typ) is list -def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]: +def unwrap_generic_list(typ: type[list[T]]) -> type[T]: """ Extracts the item type of a list type. @@ -426,21 +379,21 @@ def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]: return rewrap_annotated_type(_unwrap_generic_list, typ) -def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]: +def _unwrap_generic_list(typ: type[list[T]]) -> type[T]: "Extracts the item type of a list type (e.g. returns `T` for `List[T]`)." (list_type,) = typing.get_args(typ) # unpack single tuple element return list_type # type: ignore[no-any-return] -def is_generic_set(typ: object) -> TypeGuard[Type[set]]: +def is_generic_set(typ: object) -> TypeGuard[type[set]]: "True if the specified type is a generic set, i.e. `Set[T]`." typ = unwrap_annotated_type(typ) return typing.get_origin(typ) is set -def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]: +def unwrap_generic_set(typ: type[set[T]]) -> type[T]: """ Extracts the item type of a set type. @@ -451,21 +404,21 @@ def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]: return rewrap_annotated_type(_unwrap_generic_set, typ) -def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]: +def _unwrap_generic_set(typ: type[set[T]]) -> type[T]: "Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)." (set_type,) = typing.get_args(typ) # unpack single tuple element return set_type # type: ignore[no-any-return] -def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]: +def is_generic_dict(typ: object) -> TypeGuard[type[dict]]: "True if the specified type is a generic dictionary, i.e. `Dict[KeyType, ValueType]`." typ = unwrap_annotated_type(typ) return typing.get_origin(typ) is dict -def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]: +def unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]: """ Extracts the key and value types of a dictionary type as a tuple. @@ -476,7 +429,7 @@ def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]: return _unwrap_generic_dict(unwrap_annotated_type(typ)) -def _unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]: +def _unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]: "Extracts the key and value types of a dict type (e.g. returns (`K`, `V`) for `Dict[K, V]`)." 
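Illustrative aside (not part of the patch): the generic-container helpers above are thin wrappers over `typing.get_origin`/`typing.get_args`, which handle the builtin generics adopted here the same way they handled the old `typing.List`/`typing.Dict` spellings. Stdlib-only sketch:

```python
import typing

assert typing.get_origin(list[int]) is list
(item_type,) = typing.get_args(list[int])                  # what _unwrap_generic_list does
assert item_type is int

assert typing.get_origin(dict[str, float]) is dict
key_type, value_type = typing.get_args(dict[str, float])   # what _unwrap_generic_dict does
assert (key_type, value_type) == (str, float)
```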
key_type, value_type = typing.get_args(typ) @@ -489,7 +442,7 @@ def is_type_annotated(typ: TypeLike) -> bool: return getattr(typ, "__metadata__", None) is not None -def get_annotation(data_type: TypeLike, annotation_type: Type[T]) -> Optional[T]: +def get_annotation(data_type: TypeLike, annotation_type: type[T]) -> T | None: """ Returns the first annotation on a data type that matches the expected annotation type. @@ -518,7 +471,7 @@ def unwrap_annotated_type(typ: T) -> T: return typ -def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]: +def rewrap_annotated_type(transform: Callable[[type[S]], type[T]], typ: type[S]) -> type[T]: """ Un-boxes, transforms and re-boxes an optionally annotated type. @@ -542,7 +495,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) return transformed_type -def get_module_classes(module: types.ModuleType) -> List[type]: +def get_module_classes(module: types.ModuleType) -> list[type]: "Returns all classes declared directly in a module." def is_class_member(member: object) -> TypeGuard[type]: @@ -551,18 +504,11 @@ def get_module_classes(module: types.ModuleType) -> List[type]: return [class_type for _, class_type in inspect.getmembers(module, is_class_member)] -if sys.version_info >= (3, 9): - - def get_resolved_hints(typ: type) -> Dict[str, type]: - return typing.get_type_hints(typ, include_extras=True) - -else: - - def get_resolved_hints(typ: type) -> Dict[str, type]: - return typing.get_type_hints(typ) +def get_resolved_hints(typ: type) -> dict[str, type]: + return typing.get_type_hints(typ, include_extras=True) -def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]: +def get_class_properties(typ: type) -> Iterable[tuple[str, type | str]]: "Returns all properties of a class." if is_dataclass_type(typ): @@ -572,7 +518,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]: return resolved_hints.items() -def get_class_property(typ: type, name: str) -> Optional[type | str]: +def get_class_property(typ: type, name: str) -> type | str | None: "Looks up the annotated type of a property in a class by its property name." for property_name, property_type in get_class_properties(typ): @@ -586,7 +532,7 @@ class _ROOT: pass -def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]: +def get_referenced_types(typ: TypeLike, module: types.ModuleType | None = None) -> set[type]: """ Extracts types directly or indirectly referenced by this type. @@ -610,10 +556,10 @@ class TypeCollector: :param graph: The type dependency graph, linking types to types they depend on. """ - graph: Dict[type, Set[type]] + graph: dict[type, set[type]] @property - def references(self) -> Set[type]: + def references(self) -> set[type]: "Types collected by the type collector." dependencies = set() @@ -638,8 +584,8 @@ class TypeCollector: def run( self, typ: TypeLike, - cls: Type[DataclassInstance], - module: Optional[types.ModuleType], + cls: type[DataclassInstance], + module: types.ModuleType | None, ) -> None: """ Extracts types indirectly referenced by this type. 
@@ -702,26 +648,17 @@ class TypeCollector: for field in dataclass_fields(typ): self.run(field.type, typ, context) else: - for field_name, field_type in get_resolved_hints(typ).items(): + for _field_name, field_type in get_resolved_hints(typ).items(): self.run(field_type, typ, context) return raise TypeError(f"expected: type-like; got: {typ}") -if sys.version_info >= (3, 10): +def get_signature(fn: Callable[..., Any]) -> inspect.Signature: + "Extracts the signature of a function." - def get_signature(fn: Callable[..., Any]) -> inspect.Signature: - "Extracts the signature of a function." - - return inspect.signature(fn, eval_str=True) - -else: - - def get_signature(fn: Callable[..., Any]) -> inspect.Signature: - "Extracts the signature of a function." - - return inspect.signature(fn) + return inspect.signature(fn, eval_str=True) def is_reserved_property(name: str) -> bool: @@ -756,51 +693,20 @@ def create_module(name: str) -> types.ModuleType: return module -if sys.version_info >= (3, 10): +def create_data_type(class_name: str, fields: list[tuple[str, type]]) -> type: + """ + Creates a new data-class type dynamically. - def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type: - """ - Creates a new data-class type dynamically. + :param class_name: The name of new data-class type. + :param fields: A list of fields (and their type) that the new data-class type is expected to have. + :returns: The newly created data-class type. + """ - :param class_name: The name of new data-class type. - :param fields: A list of fields (and their type) that the new data-class type is expected to have. - :returns: The newly created data-class type. - """ - - # has the `slots` parameter - return dataclasses.make_dataclass(class_name, fields, slots=True) - -else: - - def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type: - """ - Creates a new data-class type dynamically. - - :param class_name: The name of new data-class type. - :param fields: A list of fields (and their type) that the new data-class type is expected to have. - :returns: The newly created data-class type. - """ - - cls = dataclasses.make_dataclass(class_name, fields) - - cls_dict = dict(cls.__dict__) - field_names = tuple(field.name for field in dataclasses.fields(cls)) - - cls_dict["__slots__"] = field_names - - for field_name in field_names: - cls_dict.pop(field_name, None) - cls_dict.pop("__dict__", None) - - qualname = getattr(cls, "__qualname__", None) - cls = type(cls)(cls.__name__, (), cls_dict) - if qualname is not None: - cls.__qualname__ = qualname - - return cls + # has the `slots` parameter + return dataclasses.make_dataclass(class_name, fields, slots=True) -def create_object(typ: Type[T]) -> T: +def create_object(typ: type[T]) -> T: "Creates an instance of a type." 
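Illustrative aside (not part of the patch): dropping the pre-3.10 branch of `create_data_type` works because `dataclasses.make_dataclass(..., slots=True)` now does what the removed code emulated by hand. A stdlib-only sketch with a hypothetical `Point` type:

```python
import dataclasses

# On 3.10+, slots=True gives make_dataclass the slotted behaviour the removed
# fallback branch reproduced manually.
Point = dataclasses.make_dataclass("Point", [("x", float), ("y", float)], slots=True)

p = Point(1.0, 2.0)
assert tuple(Point.__slots__) == ("x", "y")
assert not hasattr(p, "__dict__")  # slotted instances carry no per-instance __dict__
```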
if issubclass(typ, Exception): @@ -811,11 +717,7 @@ def create_object(typ: Type[T]) -> T: return object.__new__(typ) -if sys.version_info >= (3, 9): - TypeOrGeneric = Union[type, types.GenericAlias] - -else: - TypeOrGeneric = object +TypeOrGeneric = Union[type, types.GenericAlias] def is_generic_instance(obj: Any, typ: TypeLike) -> bool: @@ -885,7 +787,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool: class RecursiveChecker: - _pred: Optional[Callable[[type, Any], bool]] + _pred: Callable[[type, Any], bool] | None def __init__(self, pred: Callable[[type, Any], bool]) -> None: """ @@ -997,9 +899,9 @@ def check_recursive( obj: object, /, *, - pred: Optional[Callable[[type, Any], bool]] = None, - type_pred: Optional[Callable[[type], bool]] = None, - value_pred: Optional[Callable[[Any], bool]] = None, + pred: Callable[[type, Any], bool] | None = None, + type_pred: Callable[[type], bool] | None = None, + value_pred: Callable[[Any], bool] | None = None, ) -> bool: """ Checks if a predicate applies to all nested member properties of an object recursively. @@ -1015,7 +917,7 @@ def check_recursive( if pred is not None: raise TypeError("filter predicate not permitted when type and value predicates are present") - type_p: Callable[[Type[T]], bool] = type_pred + type_p: Callable[[type[T]], bool] = type_pred value_p: Callable[[T], bool] = value_pred pred = lambda typ, obj: not type_p(typ) or value_p(obj) # noqa: E731 diff --git a/llama_stack/strong_typing/mapping.py b/llama_stack/strong_typing/mapping.py index 408375a9f..d6c1a3172 100644 --- a/llama_stack/strong_typing/mapping.py +++ b/llama_stack/strong_typing/mapping.py @@ -11,13 +11,12 @@ Type-safe data interchange for Python data classes. """ import keyword -from typing import Optional from .auxiliary import Alias from .inspection import get_annotation -def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str: +def python_field_to_json_property(python_id: str, python_type: object | None = None) -> str: """ Map a Python field identifier to a JSON property name. diff --git a/llama_stack/strong_typing/name.py b/llama_stack/strong_typing/name.py index a1a2ae5f1..00cdc2ae2 100644 --- a/llama_stack/strong_typing/name.py +++ b/llama_stack/strong_typing/name.py @@ -11,7 +11,7 @@ Type-safe data interchange for Python data classes. """ import typing -from typing import Any, Literal, Optional, Tuple, Union +from typing import Any, Literal, Union from .auxiliary import _auxiliary_types from .inspection import ( @@ -39,7 +39,7 @@ class TypeFormatter: def __init__(self, use_union_operator: bool = False) -> None: self.use_union_operator = use_union_operator - def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str: + def union_to_str(self, data_type_args: tuple[TypeLike, ...]) -> str: if self.use_union_operator: return " | ".join(self.python_type_to_str(t) for t in data_type_args) else: @@ -100,7 +100,7 @@ class TypeFormatter: metadata = getattr(data_type, "__metadata__", None) if metadata is not None: # type is Annotated[T, ...] - metatuple: Tuple[Any, ...] = metadata + metatuple: tuple[Any, ...] = metadata arg = typing.get_args(data_type)[0] # check for auxiliary types with user-defined annotations @@ -110,7 +110,7 @@ class TypeFormatter: if arg is not auxiliary_arg: continue - auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None) + auxiliary_metatuple: tuple[Any, ...] 
| None = getattr(auxiliary_type, "__metadata__", None) if auxiliary_metatuple is None: continue diff --git a/llama_stack/strong_typing/schema.py b/llama_stack/strong_typing/schema.py index 82baddc86..7e0f29467 100644 --- a/llama_stack/strong_typing/schema.py +++ b/llama_stack/strong_typing/schema.py @@ -21,24 +21,19 @@ import json import types import typing import uuid +from collections.abc import Callable from copy import deepcopy from typing import ( + Annotated, Any, - Callable, ClassVar, - Dict, - List, Literal, - Optional, - Tuple, - Type, TypeVar, Union, overload, ) import jsonschema -from typing_extensions import Annotated from . import docstring from .auxiliary import ( @@ -71,7 +66,7 @@ OBJECT_ENUM_EXPANSION_LIMIT = 4 T = TypeVar("T") -def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]: +def get_class_docstrings(data_type: type) -> tuple[str | None, str | None]: docstr = docstring.parse_type(data_type) # check if class has a doc-string other than the auto-generated string assigned by @dataclass @@ -82,8 +77,8 @@ def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]] def get_class_property_docstrings( - data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None -) -> Dict[str, str]: + data_type: type, transform_fun: Callable[[type, str, str], str] | None = None +) -> dict[str, str]: """ Extracts the documentation strings associated with the properties of a composite type. @@ -120,7 +115,7 @@ def docstring_to_schema(data_type: type) -> Schema: return schema -def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str: +def id_from_ref(data_type: typing.ForwardRef | str | type) -> str: "Extracts the name of a possibly forward-referenced type." if isinstance(data_type, typing.ForwardRef): @@ -132,7 +127,7 @@ def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str: return data_type.__name__ -def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]: +def type_from_ref(data_type: typing.ForwardRef | str | type) -> tuple[str, type]: "Creates a type from a forward reference." if isinstance(data_type, typing.ForwardRef): @@ -148,16 +143,16 @@ def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, @dataclasses.dataclass class TypeCatalogEntry: - schema: Optional[Schema] + schema: Schema | None identifier: str - examples: Optional[JsonType] = None + examples: JsonType | None = None class TypeCatalog: "Maintains an association of well-known Python types to their JSON schema." - _by_type: Dict[TypeLike, TypeCatalogEntry] - _by_name: Dict[str, TypeCatalogEntry] + _by_type: dict[TypeLike, TypeCatalogEntry] + _by_name: dict[str, TypeCatalogEntry] def __init__(self) -> None: self._by_type = {} @@ -174,9 +169,9 @@ class TypeCatalog: def add( self, data_type: TypeLike, - schema: Optional[Schema], + schema: Schema | None, identifier: str, - examples: Optional[List[JsonType]] = None, + examples: list[JsonType] | None = None, ) -> None: if isinstance(data_type, typing.ForwardRef): raise TypeError("forward references cannot be used to register a type") @@ -202,17 +197,17 @@ class SchemaOptions: definitions_path: str = "#/definitions/" use_descriptions: bool = True use_examples: bool = True - property_description_fun: Optional[Callable[[type, str, str], str]] = None + property_description_fun: Callable[[type, str, str], str] | None = None class JsonSchemaGenerator: "Creates a JSON schema with user-defined type definitions." 
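Illustrative aside (not part of the patch): the generator is normally driven through the module-level helpers further down in this file (`classdef_to_schema`, `print_schema`). A sketch with a hypothetical `Endpoint` dataclass, assuming the default `SchemaOptions` map the doc-string text onto `description` fields as the surrounding code suggests:

```python
import dataclasses

from llama_stack.strong_typing.schema import print_schema


@dataclasses.dataclass
class Endpoint:
    """
    A network endpoint.

    :param host: Host name to connect to.
    :param port: TCP port to connect to.
    """

    host: str
    port: int = 443


# Emits a JSON Schema in which `host` is required, `port` carries a default of 443,
# and the doc-string text feeds the description fields.
print_schema(Endpoint)
```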
type_catalog: ClassVar[TypeCatalog] = TypeCatalog() - types_used: Dict[str, TypeLike] + types_used: dict[str, TypeLike] options: SchemaOptions - def __init__(self, options: Optional[SchemaOptions] = None): + def __init__(self, options: SchemaOptions | None = None): if options is None: self.options = SchemaOptions() else: @@ -244,13 +239,13 @@ class JsonSchemaGenerator: def _(self, arg: MaxLength) -> Schema: return {"maxLength": arg.value} - def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema: + def _with_metadata(self, type_schema: Schema, metadata: tuple[Any, ...] | None) -> Schema: if metadata: for m in metadata: type_schema.update(self._metadata_to_schema(m)) return type_schema - def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]: + def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: dict | None = None) -> Schema | None: """ Returns the JSON schema associated with a simple, unrestricted type. @@ -314,7 +309,7 @@ class JsonSchemaGenerator: self, data_type: TypeLike, force_expand: bool = False, - json_schema_extra: Optional[dict] = None, + json_schema_extra: dict | None = None, ) -> Schema: common_info = {} if json_schema_extra and "deprecated" in json_schema_extra: @@ -325,7 +320,7 @@ class JsonSchemaGenerator: self, data_type: TypeLike, force_expand: bool = False, - json_schema_extra: Optional[dict] = None, + json_schema_extra: dict | None = None, ) -> Schema: """ Returns the JSON schema associated with a type. @@ -381,7 +376,7 @@ class JsonSchemaGenerator: return {"$ref": f"{self.options.definitions_path}{identifier}"} if is_type_enum(typ): - enum_type: Type[enum.Enum] = typ + enum_type: type[enum.Enum] = typ value_types = enum_value_types(enum_type) if len(value_types) != 1: raise ValueError( @@ -496,8 +491,8 @@ class JsonSchemaGenerator: members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a))) property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun) - properties: Dict[str, Schema] = {} - required: List[str] = [] + properties: dict[str, Schema] = {} + required: list[str] = [] for property_name, property_type in get_class_properties(typ): # rename property if an alias name is specified alias = get_annotation(property_type, Alias) @@ -530,16 +525,7 @@ class JsonSchemaGenerator: # check if value can be directly represented in JSON if isinstance( def_value, - ( - bool, - int, - float, - str, - enum.Enum, - datetime.datetime, - datetime.date, - datetime.time, - ), + bool | int | float | str | enum.Enum | datetime.datetime | datetime.date | datetime.time, ): property_def["default"] = object_to_json(def_value) @@ -587,7 +573,7 @@ class JsonSchemaGenerator: return type_schema - def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]: + def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> tuple[Schema, dict[str, Schema]]: """ Returns the JSON schema associated with a type and any nested types. 
@@ -604,7 +590,7 @@ class JsonSchemaGenerator: try: type_schema = self.type_to_schema(data_type, force_expand=force_expand) - types_defined: Dict[str, Schema] = {} + types_defined: dict[str, Schema] = {} while len(self.types_used) > len(types_defined): # make a snapshot copy; original collection is going to be modified types_undefined = { @@ -635,7 +621,7 @@ class Validator(enum.Enum): def classdef_to_schema( data_type: TypeLike, - options: Optional[SchemaOptions] = None, + options: SchemaOptions | None = None, validator: Validator = Validator.Latest, ) -> Schema: """ @@ -689,7 +675,7 @@ def print_schema(data_type: type) -> None: print(json.dumps(s, indent=4)) -def get_schema_identifier(data_type: type) -> Optional[str]: +def get_schema_identifier(data_type: type) -> str | None: if data_type in JsonSchemaGenerator.type_catalog: return JsonSchemaGenerator.type_catalog.get(data_type).identifier else: @@ -698,9 +684,9 @@ def get_schema_identifier(data_type: type) -> Optional[str]: def register_schema( data_type: T, - schema: Optional[Schema] = None, - name: Optional[str] = None, - examples: Optional[List[JsonType]] = None, + schema: Schema | None = None, + name: str | None = None, + examples: list[JsonType] | None = None, ) -> T: """ Associates a type with a JSON schema definition. @@ -721,22 +707,22 @@ def register_schema( @overload -def json_schema_type(cls: Type[T], /) -> Type[T]: ... +def json_schema_type(cls: type[T], /) -> type[T]: ... @overload -def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ... +def json_schema_type(cls: None, *, schema: Schema | None = None) -> Callable[[type[T]], type[T]]: ... def json_schema_type( - cls: Optional[Type[T]] = None, + cls: type[T] | None = None, *, - schema: Optional[Schema] = None, - examples: Optional[List[JsonType]] = None, -) -> Union[Type[T], Callable[[Type[T]], Type[T]]]: + schema: Schema | None = None, + examples: list[JsonType] | None = None, +) -> type[T] | Callable[[type[T]], type[T]]: """Decorator to add user-defined schema definition to a class.""" - def wrap(cls: Type[T]) -> Type[T]: + def wrap(cls: type[T]) -> type[T]: return register_schema(cls, schema, examples=examples) # see if decorator is used as @json_schema_type or @json_schema_type() diff --git a/llama_stack/strong_typing/serialization.py b/llama_stack/strong_typing/serialization.py index c00a0aad5..3e34945ad 100644 --- a/llama_stack/strong_typing/serialization.py +++ b/llama_stack/strong_typing/serialization.py @@ -14,7 +14,7 @@ import inspect import json import sys from types import ModuleType -from typing import Any, Optional, TextIO, TypeVar +from typing import Any, TextIO, TypeVar from .core import JsonType from .deserializer import create_deserializer @@ -42,7 +42,7 @@ def object_to_json(obj: Any) -> JsonType: return generator.generate(obj) -def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object: +def json_to_object(typ: TypeLike, data: JsonType, *, context: ModuleType | None = None) -> object: """ Creates an object from a representation that has been de-serialized from JSON. 
diff --git a/llama_stack/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py index 17848c14b..7d827f73c 100644 --- a/llama_stack/strong_typing/serializer.py +++ b/llama_stack/strong_typing/serializer.py @@ -20,19 +20,13 @@ import ipaddress import sys import typing import uuid +from collections.abc import Callable from types import FunctionType, MethodType, ModuleType from typing import ( Any, - Callable, - Dict, Generic, - List, Literal, NamedTuple, - Optional, - Set, - Tuple, - Type, TypeVar, Union, ) @@ -133,7 +127,7 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]): class EnumSerializer(Serializer[enum.Enum]): - def generate(self, obj: enum.Enum) -> Union[int, str]: + def generate(self, obj: enum.Enum) -> int | str: value = obj.value if isinstance(value, int): return value @@ -141,12 +135,12 @@ class EnumSerializer(Serializer[enum.Enum]): class UntypedListSerializer(Serializer[list]): - def generate(self, obj: list) -> List[JsonType]: + def generate(self, obj: list) -> list[JsonType]: return [object_to_json(item) for item in obj] class UntypedDictSerializer(Serializer[dict]): - def generate(self, obj: dict) -> Dict[str, JsonType]: + def generate(self, obj: dict) -> dict[str, JsonType]: if obj and isinstance(next(iter(obj.keys())), enum.Enum): iterator = ((key.value, object_to_json(value)) for key, value in obj.items()) else: @@ -155,41 +149,41 @@ class UntypedDictSerializer(Serializer[dict]): class UntypedSetSerializer(Serializer[set]): - def generate(self, obj: set) -> List[JsonType]: + def generate(self, obj: set) -> list[JsonType]: return [object_to_json(item) for item in obj] class UntypedTupleSerializer(Serializer[tuple]): - def generate(self, obj: tuple) -> List[JsonType]: + def generate(self, obj: tuple) -> list[JsonType]: return [object_to_json(item) for item in obj] class TypedCollectionSerializer(Serializer, Generic[T]): generator: Serializer[T] - def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None: + def __init__(self, item_type: type[T], context: ModuleType | None) -> None: self.generator = _get_serializer(item_type, context) class TypedListSerializer(TypedCollectionSerializer[T]): - def generate(self, obj: List[T]) -> List[JsonType]: + def generate(self, obj: list[T]) -> list[JsonType]: return [self.generator.generate(item) for item in obj] class TypedStringDictSerializer(TypedCollectionSerializer[T]): - def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None: + def __init__(self, value_type: type[T], context: ModuleType | None) -> None: super().__init__(value_type, context) - def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]: + def generate(self, obj: dict[str, T]) -> dict[str, JsonType]: return {key: self.generator.generate(value) for key, value in obj.items()} class TypedEnumDictSerializer(TypedCollectionSerializer[T]): def __init__( self, - key_type: Type[enum.Enum], - value_type: Type[T], - context: Optional[ModuleType], + key_type: type[enum.Enum], + value_type: type[T], + context: ModuleType | None, ) -> None: super().__init__(value_type, context) @@ -203,22 +197,22 @@ class TypedEnumDictSerializer(TypedCollectionSerializer[T]): if value_type is not str: raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values") - def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]: + def generate(self, obj: dict[enum.Enum, T]) -> dict[str, JsonType]: return {key.value: self.generator.generate(value) for key, value in obj.items()} class 
TypedSetSerializer(TypedCollectionSerializer[T]): - def generate(self, obj: Set[T]) -> JsonType: + def generate(self, obj: set[T]) -> JsonType: return [self.generator.generate(item) for item in obj] class TypedTupleSerializer(Serializer[tuple]): - item_generators: Tuple[Serializer, ...] + item_generators: tuple[Serializer, ...] - def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None: + def __init__(self, item_types: tuple[type, ...], context: ModuleType | None) -> None: self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types) - def generate(self, obj: tuple) -> List[JsonType]: + def generate(self, obj: tuple) -> list[JsonType]: return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)] @@ -250,16 +244,16 @@ class FieldSerializer(Generic[T]): self.property_name = property_name self.generator = generator - def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None: + def generate_field(self, obj: object, object_dict: dict[str, JsonType]) -> None: value = getattr(obj, self.field_name) if value is not None: object_dict[self.property_name] = self.generator.generate(value) class TypedClassSerializer(Serializer[T]): - property_generators: List[FieldSerializer] + property_generators: list[FieldSerializer] - def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None: + def __init__(self, class_type: type[T], context: ModuleType | None) -> None: self.property_generators = [ FieldSerializer( field_name, @@ -269,8 +263,8 @@ class TypedClassSerializer(Serializer[T]): for field_name, field_type in get_class_properties(class_type) ] - def generate(self, obj: T) -> Dict[str, JsonType]: - object_dict: Dict[str, JsonType] = {} + def generate(self, obj: T) -> dict[str, JsonType]: + object_dict: dict[str, JsonType] = {} for property_generator in self.property_generators: property_generator.generate_field(obj, object_dict) @@ -278,12 +272,12 @@ class TypedClassSerializer(Serializer[T]): class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]): - def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None: + def __init__(self, class_type: type[NamedTuple], context: ModuleType | None) -> None: super().__init__(class_type, context) class DataclassSerializer(TypedClassSerializer[T]): - def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None: + def __init__(self, class_type: type[T], context: ModuleType | None) -> None: super().__init__(class_type, context) @@ -295,7 +289,7 @@ class UnionSerializer(Serializer): class LiteralSerializer(Serializer): generator: Serializer - def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None: + def __init__(self, values: tuple[Any, ...], context: ModuleType | None) -> None: literal_type_tuple = tuple(type(value) for value in values) literal_type_set = set(literal_type_tuple) if len(literal_type_set) != 1: @@ -312,12 +306,12 @@ class LiteralSerializer(Serializer): class UntypedNamedTupleSerializer(Serializer): - fields: Dict[str, str] + fields: dict[str, str] - def __init__(self, class_type: Type[NamedTuple]) -> None: + def __init__(self, class_type: type[NamedTuple]) -> None: # named tuples are also instances of tuple self.fields = {} - field_names: Tuple[str, ...] = class_type._fields + field_names: tuple[str, ...] 
= class_type._fields for field_name in field_names: self.fields[field_name] = python_field_to_json_property(field_name) @@ -351,7 +345,7 @@ class UntypedClassSerializer(Serializer): return object_dict -def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer: +def create_serializer(typ: TypeLike, context: ModuleType | None = None) -> Serializer: """ Creates a serializer engine to produce an object that can be directly converted into a JSON string. @@ -376,8 +370,8 @@ def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Se return _get_serializer(typ, context) -def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer: - if isinstance(typ, (str, typing.ForwardRef)): +def _get_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer: + if isinstance(typ, str | typing.ForwardRef): if context is None: raise TypeError(f"missing context for evaluating type: {typ}") @@ -390,13 +384,13 @@ def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer: return _create_serializer(typ, context) -@functools.lru_cache(maxsize=None) +@functools.cache def _fetch_serializer(typ: type) -> Serializer: context = sys.modules[typ.__module__] return _create_serializer(typ, context) -def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer: +def _create_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer: # check for well-known types if typ is type(None): return NoneSerializer() diff --git a/llama_stack/strong_typing/slots.py b/llama_stack/strong_typing/slots.py index c1a3293d8..772834140 100644 --- a/llama_stack/strong_typing/slots.py +++ b/llama_stack/strong_typing/slots.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Tuple, Type, TypeVar +from typing import Any, TypeVar T = TypeVar("T") class SlotsMeta(type): - def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T: + def __new__(cls: type[T], name: str, bases: tuple[type, ...], ns: dict[str, Any]) -> T: # caller may have already provided slots, in which case just retain them and keep going - slots: Tuple[str, ...] = ns.get("__slots__", ()) + slots: tuple[str, ...] = ns.get("__slots__", ()) # add fields with type annotations to slots - annotations: Dict[str, Any] = ns.get("__annotations__", {}) + annotations: dict[str, Any] = ns.get("__annotations__", {}) members = tuple(member for member in annotations.keys() if member not in slots) # assign slots diff --git a/llama_stack/strong_typing/topological.py b/llama_stack/strong_typing/topological.py index 28bf4bd0f..9502a5887 100644 --- a/llama_stack/strong_typing/topological.py +++ b/llama_stack/strong_typing/topological.py @@ -10,14 +10,15 @@ Type-safe data interchange for Python data classes. :see: https://github.com/hunyadi/strong_typing """ -from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar +from collections.abc import Callable, Iterable +from typing import TypeVar from .inspection import TypeCollector T = TypeVar("T") -def topological_sort(graph: Dict[T, Set[T]]) -> List[T]: +def topological_sort(graph: dict[T, set[T]]) -> list[T]: """ Performs a topological sort of a graph. 
@@ -29,9 +30,9 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]: """ # empty list that will contain the sorted nodes (in reverse order) - ordered: List[T] = [] + ordered: list[T] = [] - seen: Dict[T, bool] = {} + seen: dict[T, bool] = {} def _visit(n: T) -> None: status = seen.get(n) @@ -57,8 +58,8 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]: def type_topological_sort( types: Iterable[type], - dependency_fn: Optional[Callable[[type], Iterable[type]]] = None, -) -> List[type]: + dependency_fn: Callable[[type], Iterable[type]] | None = None, +) -> list[type]: """ Performs a topological sort of a list of types. @@ -78,7 +79,7 @@ def type_topological_sort( graph = collector.graph if dependency_fn: - new_types: Set[type] = set() + new_types: set[type] = set() for source_type, references in graph.items(): dependent_types = dependency_fn(source_type) references.update(dependent_types) diff --git a/llama_stack/templates/ollama/__init__.py b/llama_stack/templates/ollama/__init__.py new file mode 100644 index 000000000..3a2c40f27 --- /dev/null +++ b/llama_stack/templates/ollama/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .ollama import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml new file mode 100644 index 000000000..cbf4281a2 --- /dev/null +++ b/llama_stack/templates/ollama/build.yaml @@ -0,0 +1,39 @@ +version: 2 +distribution_spec: + description: Use (an external) Ollama server for running LLM inference + providers: + inference: + - remote::ollama + vector_io: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + files: + - inline::localfs + post_training: + - inline::huggingface + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::rag-runtime + - remote::model-context-protocol + - remote::wolfram-alpha +image_type: conda +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ollama/doc_template.md b/llama_stack/templates/ollama/doc_template.md new file mode 100644 index 000000000..83f73bdc0 --- /dev/null +++ b/llama_stack/templates/ollama/doc_template.md @@ -0,0 +1,168 @@ +# Ollama Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} {{ model.doc_string }}` +{% endfor %} +{% endif %} + +## Prerequisites + +### Ollama Server + +This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps: + +1. 
**Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/) + +2. **Start the Ollama server**: + ```bash + ollama serve + ``` + By default, Ollama serves on `http://127.0.0.1:11434` + +3. **Pull the required models**: + ```bash + # Pull the inference model + ollama pull meta-llama/Llama-3.2-3B-Instruct + + # Pull the embedding model + ollama pull all-minilm:latest + + # (Optional) Pull the safety model for run-with-safety.yaml + ollama pull meta-llama/Llama-Guard-3-1B + ``` + +## Supported Services + +### Inference: Ollama +Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable. + +### Vector IO: FAISS +Provides vector storage capabilities using FAISS for embeddings and similarity search operations. + +### Safety: Llama Guard (Optional) +When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server. + +### Agents: Meta Reference +Provides agent execution capabilities using the meta-reference implementation. + +### Post-Training: Hugging Face +Supports model fine-tuning using Hugging Face integration. + +### Tool Runtime +Supports various external tools including: +- Brave Search +- Tavily Search +- RAG Runtime +- Model Context Protocol +- Wolfram Alpha + +## Running Llama Stack with Ollama + +You can do this via Conda or venv (build code), or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=8321 +docker run \ + -it \ + --pull always \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Via Conda + +```bash +llama stack build --template ollama --image-type conda +llama stack run ./run.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Via venv + +If you've set up your local development environment, you can also build the image using your local virtual environment. + +```bash +llama stack build --template ollama --image-type venv +llama stack run ./run.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL +``` + +### Running with Safety + +To enable safety checks, use the `run-with-safety.yaml` configuration: + +```bash +llama stack run ./run-with-safety.yaml \ + --port 8321 \ + --env OLLAMA_URL=$OLLAMA_URL \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL +``` + +## Example Usage + +Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client: + +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url="http://localhost:8321") + +# Run inference +response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.2-3B-Instruct", + messages=[{"role": "user", "content": "Hello, how are you?"}], +) +print(response.completion_message.content) +``` + +## Troubleshooting + +### Common Issues + +1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL. + +2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull `. + +3. 
**Performance issues**: Consider using more powerful models or adjusting the Ollama server configuration for better performance. + +### Logs + +Check the Ollama server logs for any issues: +```bash +# Ollama logs are typically available in: +# - macOS: ~/Library/Logs/Ollama/ +# - Linux: ~/.ollama/logs/ +``` \ No newline at end of file diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py new file mode 100644 index 000000000..83491da11 --- /dev/null +++ b/llama_stack/templates/ollama/ollama.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.apis.models import ModelType +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig +from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::ollama"], + "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "files": ["inline::localfs"], + "post_training": ["inline::huggingface"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::rag-runtime", + "remote::model-context-protocol", + "remote::wolfram-alpha", + ], + } + name = "ollama" + inference_provider = Provider( + provider_id="ollama", + provider_type="remote::ollama", + config=OllamaImplConfig.sample_run_config(), + ) + vector_io_provider_faiss = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config( + f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/distributions/{name}" + ), + ) + files_provider = Provider( + provider_id="meta-reference-files", + provider_type="inline::localfs", + config=LocalfsFilesImplConfig.sample_run_config( + f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}" + ), + ) + posttraining_provider = Provider( + provider_id="huggingface", + provider_type="inline::huggingface", + config=HuggingFacePostTrainingConfig.sample_run_config( + f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}" + ), + ) + inference_model = ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="ollama", + ) + safety_model = ModelInput( + model_id="${env.SAFETY_MODEL}", + provider_id="ollama", + ) + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="ollama", + provider_model_id="all-minilm:latest", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) + 
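+    # Default tool groups wired into both generated run configs below; each
+    # provider_id must match a tool_runtime provider declared for this template.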
default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ), + ] + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Use (an external) Ollama server for running LLM inference", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + "vector_io": [vector_io_provider_faiss], + "files": [files_provider], + "post_training": [posttraining_provider], + }, + default_models=[inference_model, embedding_model], + default_tool_groups=default_tool_groups, + ), + "run-with-safety.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + "vector_io": [vector_io_provider_faiss], + "files": [files_provider], + "post_training": [posttraining_provider], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + config={}, + ), + Provider( + provider_id="code-scanner", + provider_type="inline::code-scanner", + config={}, + ), + ], + }, + default_models=[ + inference_model, + safety_model, + embedding_model, + ], + default_shields=[ + ShieldInput( + shield_id="${env.SAFETY_MODEL}", + provider_id="llama-guard", + ), + ShieldInput( + shield_id="CodeScanner", + provider_id="code-scanner", + ), + ], + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "8321", + "Port for the Llama Stack distribution server", + ), + "OLLAMA_URL": ( + "http://127.0.0.1:11434", + "URL of the Ollama server", + ), + "INFERENCE_MODEL": ( + "meta-llama/Llama-3.2-3B-Instruct", + "Inference model loaded into the Ollama server", + ), + "SAFETY_MODEL": ( + "meta-llama/Llama-Guard-3-1B", + "Safety model loaded into the Ollama server", + ), + }, + ) diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml new file mode 100644 index 000000000..27d426829 --- /dev/null +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -0,0 +1,158 @@ +version: 2 +image_name: ollama +apis: +- agents +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ollama + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + - provider_id: code-scanner + provider_type: inline::code-scanner + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + 
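+      # trace store for the sqlite telemetry sink; override its location with SQLITE_STORE_DIR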
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: + api_key: ${env.WOLFRAM_ALPHA_API_KEY:=} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: ollama + model_type: llm +- metadata: {} + model_id: ${env.SAFETY_MODEL} + provider_id: ollama + model_type: llm +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: ollama + provider_model_id: all-minilm:latest + model_type: embedding +shields: +- shield_id: ${env.SAFETY_MODEL} + provider_id: llama-guard +- shield_id: CodeScanner + provider_id: code-scanner +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha +server: + port: 8321 diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml new file mode 100644 index 000000000..c4d9a668f --- /dev/null +++ b/llama_stack/templates/ollama/run.yaml @@ -0,0 +1,148 @@ +version: 2 +image_name: ollama +apis: +- agents +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - 
provider_id: ollama + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + - provider_id: wolfram-alpha + provider_type: remote::wolfram-alpha + config: + api_key: ${env.WOLFRAM_ALPHA_API_KEY:=} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: ollama + model_type: llm +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: ollama + 
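+  # provider_model_id is the Ollama tag, pulled with `ollama pull all-minilm:latest`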
provider_model_id: all-minilm:latest + model_type: embedding +shields: [] +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::wolfram_alpha + provider_id: wolfram-alpha +server: + port: 8321 diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index eff04a40f..ac2feb228 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -70,11 +70,44 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]: default_value = field.default_factory() # HACK ALERT: # If the default value contains a path that looks like it came from RUNTIME_BASE_DIR, - # replace it with a generic ~/.llama/ path for documentation - if isinstance(default_value, str) and "/.llama/" in default_value: - if ".llama/" in default_value: + # replace it with a generic XDG-compliant path for documentation + if isinstance(default_value, str): + # Handle legacy .llama/ paths + if "/.llama/" in default_value: path_part = default_value.split(".llama/")[-1] - default_value = f"~/.llama/{path_part}" + # Use appropriate XDG directory based on path content + if path_part.startswith("runtime/"): + default_value = f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/{path_part}" + else: + default_value = f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/{path_part}" + # Handle XDG state paths (runtime data) + elif "/llama-stack/runtime/" in default_value: + path_part = default_value.split("/llama-stack/runtime/")[-1] + default_value = ( + f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/runtime/{path_part}" + ) + # Handle XDG data paths + elif "/llama-stack/data/" in default_value or "/llama-stack/checkpoints/" in default_value: + if "/llama-stack/data/" in default_value: + path_part = default_value.split("/llama-stack/data/")[-1] + default_value = ( + f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/data/{path_part}" + ) + else: + path_part = default_value.split("/llama-stack/checkpoints/")[-1] + default_value = ( + f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/checkpoints/{path_part}" + ) + # Handle XDG config paths + elif "/llama-stack/" in default_value and ( + "/config/" in default_value or "/distributions/" in default_value + ): + if "/config/" in default_value: + path_part = default_value.split("/llama-stack/")[-1] + default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}" + else: + path_part = default_value.split("/llama-stack/")[-1] + default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}" except Exception: default_value = "" elif field.default is None: @@ -201,7 +234,9 @@ def generate_provider_docs(provider_spec: Any, api_name: str) -> str: if sample_config_func is not None: sig = inspect.signature(sample_config_func) if "__distro_dir__" in sig.parameters: - sample_config = sample_config_func(__distro_dir__="~/.llama/dummy") + sample_config = sample_config_func( + __distro_dir__="${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy" + ) else: sample_config = sample_config_func() diff --git a/tests/integration/test_xdg_e2e.py b/tests/integration/test_xdg_e2e.py new file mode 100644 index 000000000..29bc9f3f4 --- /dev/null +++ b/tests/integration/test_xdg_e2e.py @@ -0,0 +1,593 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import yaml + +from llama_stack.distribution.utils.xdg_utils import ( + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, +) + + +class TestXDGEndToEnd(unittest.TestCase): + """End-to-end tests for XDG compliance workflows.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + self.env_vars = [ + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + "XDG_STATE_HOME", + "XDG_CACHE_HOME", + "LLAMA_STACK_CONFIG_DIR", + "SQLITE_STORE_DIR", + "FILES_STORAGE_DIR", + ] + + for key in self.env_vars: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def clear_env_vars(self): + """Clear all relevant environment variables.""" + for key in self.env_vars: + os.environ.pop(key, None) + + def create_realistic_legacy_structure(self, base_dir: Path) -> Path: + """Create a realistic legacy ~/.llama directory structure.""" + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + # Create distributions with realistic content + distributions_dir = legacy_dir / "distributions" + distributions_dir.mkdir() + + # Ollama distribution + ollama_dir = distributions_dir / "ollama" + ollama_dir.mkdir() + + ollama_run_yaml = ollama_dir / "ollama-run.yaml" + ollama_run_yaml.write_text(""" +version: 2 +apis: + - inference + - safety + - memory + - vector_io + - agents + - files +providers: + inference: + - provider_type: remote::ollama + config: + url: http://localhost:11434 +""") + + ollama_build_yaml = ollama_dir / "build.yaml" + ollama_build_yaml.write_text(""" +name: ollama +description: Ollama inference provider +docker_image: ollama:latest +""") + + # Create providers.d structure + providers_dir = legacy_dir / "providers.d" + providers_dir.mkdir() + + remote_dir = providers_dir / "remote" + remote_dir.mkdir() + + inference_dir = remote_dir / "inference" + inference_dir.mkdir() + + custom_provider = inference_dir / "custom-inference.yaml" + custom_provider.write_text(""" +provider_type: remote::custom +config: + url: http://localhost:8080 + api_key: test_key +""") + + # Create checkpoints with model files + checkpoints_dir = legacy_dir / "checkpoints" + checkpoints_dir.mkdir() + + model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct" + model_dir.mkdir(parents=True) + + # Create fake model files + (model_dir / "consolidated.00.pth").write_bytes(b"fake model weights" * 1000) + (model_dir / "params.json").write_text('{"dim": 2048, "n_layers": 22}') + (model_dir / "tokenizer.model").write_bytes(b"fake tokenizer" * 100) + + # Create runtime with databases + runtime_dir = legacy_dir / "runtime" + runtime_dir.mkdir() + + (runtime_dir / "trace_store.db").write_text("SQLite format 3\x00" + "fake database content") + (runtime_dir / "agent_sessions.db").write_text("SQLite format 3\x00" + "fake agent sessions") + + # Create config files + 
(legacy_dir / "config.json").write_text('{"version": "0.2.13", "last_updated": "2024-01-01"}') + + return legacy_dir + + def verify_xdg_migration_complete(self, base_dir: Path, legacy_dir: Path): + """Verify that migration to XDG structure is complete and correct.""" + config_dir = base_dir / ".config" / "llama-stack" + data_dir = base_dir / ".local" / "share" / "llama-stack" + state_dir = base_dir / ".local" / "state" / "llama-stack" + + # Verify distributions moved to config + self.assertTrue((config_dir / "distributions").exists()) + self.assertTrue((config_dir / "distributions" / "ollama").exists()) + self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists()) + + # Verify YAML content is preserved + yaml_content = (config_dir / "distributions" / "ollama" / "ollama-run.yaml").read_text() + self.assertIn("version: 2", yaml_content) + self.assertIn("remote::ollama", yaml_content) + + # Verify providers.d moved to config + self.assertTrue((config_dir / "providers.d").exists()) + self.assertTrue((config_dir / "providers.d" / "remote" / "inference").exists()) + self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom-inference.yaml").exists()) + + # Verify checkpoints moved to data + self.assertTrue((data_dir / "checkpoints").exists()) + self.assertTrue((data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct").exists()) + self.assertTrue( + (data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists() + ) + + # Verify model file content preserved + model_file = data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth" + self.assertGreater(model_file.stat().st_size, 1000) # Should be substantial size + + # Verify runtime moved to state + self.assertTrue((state_dir / "runtime").exists()) + self.assertTrue((state_dir / "runtime" / "trace_store.db").exists()) + self.assertTrue((state_dir / "runtime" / "agent_sessions.db").exists()) + + # Verify database files preserved + db_file = state_dir / "runtime" / "trace_store.db" + db_content = db_file.read_text() + self.assertIn("SQLite format 3", db_content) + + def test_fresh_installation_xdg_compliance(self): + """Test fresh installation uses XDG-compliant paths.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + # Set custom XDG paths + os.environ["XDG_CONFIG_HOME"] = str(base_dir / "custom_config") + os.environ["XDG_DATA_HOME"] = str(base_dir / "custom_data") + os.environ["XDG_STATE_HOME"] = str(base_dir / "custom_state") + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + # Mock that no legacy directory exists + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Fresh installation should use XDG paths + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + state_dir = get_llama_stack_state_dir() + + self.assertEqual(config_dir, base_dir / "custom_config" / "llama-stack") + self.assertEqual(data_dir, base_dir / "custom_data" / "llama-stack") + self.assertEqual(state_dir, base_dir / "custom_state" / "llama-stack") + + def test_complete_migration_workflow(self): + """Test complete migration workflow from legacy to XDG.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") 
as mock_home: + mock_home.return_value = base_dir + + # Create realistic legacy structure + legacy_dir = self.create_realistic_legacy_structure(base_dir) + + # Verify legacy structure exists + self.assertTrue(legacy_dir.exists()) + self.assertTrue((legacy_dir / "distributions" / "ollama" / "ollama-run.yaml").exists()) + + # Perform migration + from llama_stack.cli.migrate_xdg import migrate_to_xdg + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify migration completed successfully + self.verify_xdg_migration_complete(base_dir, legacy_dir) + + # Verify legacy directory was removed + self.assertFalse(legacy_dir.exists()) + + def test_migration_preserves_file_integrity(self): + """Test that migration preserves file integrity and permissions.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_realistic_legacy_structure(base_dir) + + # Set specific permissions on files + config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml" + config_file.chmod(0o600) + original_stat = config_file.stat() + + # Perform migration + from llama_stack.cli.migrate_xdg import migrate_to_xdg + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify file integrity + migrated_config = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml" + self.assertTrue(migrated_config.exists()) + + # Verify content is identical + migrated_content = migrated_config.read_text() + self.assertIn("version: 2", migrated_content) + self.assertIn("remote::ollama", migrated_content) + + # Verify permissions preserved + migrated_stat = migrated_config.stat() + self.assertEqual(original_stat.st_mode, migrated_stat.st_mode) + + def test_mixed_legacy_and_xdg_environment(self): + """Test behavior in mixed legacy and XDG environment.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + # Set partial XDG environment + os.environ["XDG_CONFIG_HOME"] = str(base_dir / "xdg_config") + # Leave XDG_DATA_HOME and XDG_STATE_HOME unset + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy directory + legacy_dir = self.create_realistic_legacy_structure(base_dir) + + # Should use legacy directory since it exists + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + state_dir = get_llama_stack_state_dir() + + self.assertEqual(config_dir, legacy_dir) + self.assertEqual(data_dir, legacy_dir) + self.assertEqual(state_dir, legacy_dir) + + def test_template_rendering_with_xdg_paths(self): + """Test that templates render correctly with XDG paths.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + # Set XDG environment + os.environ["XDG_STATE_HOME"] = str(base_dir / "state") + os.environ["XDG_DATA_HOME"] = str(base_dir / "data") + + # Mock shell environment variable expansion + def mock_env_expand(template_string): + """Mock shell environment variable expansion.""" + result = template_string + result = 
result.replace("${env.XDG_STATE_HOME:-~/.local/state}", str(base_dir / "state")) + result = result.replace("${env.XDG_DATA_HOME:-~/.local/share}", str(base_dir / "data")) + return result + + # Test template patterns + template_patterns = [ + "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama", + "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files", + ] + + expected_results = [ + str(base_dir / "state" / "llama-stack" / "distributions" / "ollama"), + str(base_dir / "data" / "llama-stack" / "distributions" / "ollama" / "files"), + ] + + for pattern, expected in zip(template_patterns, expected_results, strict=False): + with self.subTest(pattern=pattern): + expanded = mock_env_expand(pattern) + self.assertEqual(expanded, expected) + + def test_cli_integration_with_xdg_paths(self): + """Test CLI integration works correctly with XDG paths.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_realistic_legacy_structure(base_dir) + + # Test CLI migrate command + import argparse + + from llama_stack.cli.migrate_xdg import MigrateXDG + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + migrate_cmd = MigrateXDG.create(subparsers) + + # Test dry-run + args = parser.parse_args(["migrate-xdg", "--dry-run"]) + + with patch("builtins.print") as mock_print: + result = migrate_cmd._run_migrate_xdg_cmd(args) + self.assertEqual(result, 0) + + # Should print dry-run information + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Dry run mode" in call for call in print_calls)) + + # Legacy directory should still exist after dry-run + self.assertTrue(legacy_dir.exists()) + + def test_config_dirs_integration_after_migration(self): + """Test that config_dirs works correctly after migration.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create and migrate legacy structure + self.create_realistic_legacy_structure(base_dir) + + from llama_stack.cli.migrate_xdg import migrate_to_xdg + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Clear module cache to ensure fresh import + import sys + + if "llama_stack.distribution.utils.config_dirs" in sys.modules: + del sys.modules["llama_stack.distribution.utils.config_dirs"] + + # Import config_dirs after migration + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # Should use XDG paths + self.assertEqual(LLAMA_STACK_CONFIG_DIR, base_dir / ".config" / "llama-stack") + self.assertEqual(DEFAULT_CHECKPOINT_DIR, base_dir / ".local" / "share" / "llama-stack" / "checkpoints") + self.assertEqual(RUNTIME_BASE_DIR, base_dir / ".local" / "state" / "llama-stack" / "runtime") + self.assertEqual(DISTRIBS_BASE_DIR, base_dir / ".config" / "llama-stack" / "distributions") + + def test_real_file_operations_with_xdg_paths(self): + """Test real file operations work correctly with XDG paths.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as 
temp_dir: + base_dir = Path(temp_dir) + + # Set XDG environment + os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config") + os.environ["XDG_DATA_HOME"] = str(base_dir / "data") + os.environ["XDG_STATE_HOME"] = str(base_dir / "state") + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Get XDG paths + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + state_dir = get_llama_stack_state_dir() + + # Create directories + config_dir.mkdir(parents=True) + data_dir.mkdir(parents=True) + state_dir.mkdir(parents=True) + + # Test writing configuration files + config_file = config_dir / "test_config.yaml" + config_data = {"version": "2", "test": True} + + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + # Test reading configuration files + with open(config_file) as f: + loaded_config = yaml.safe_load(f) + + self.assertEqual(loaded_config, config_data) + + # Test creating nested directory structure + model_dir = data_dir / "checkpoints" / "meta-llama" / "test-model" + model_dir.mkdir(parents=True) + + # Test writing large files + model_file = model_dir / "model.bin" + test_data = b"test model data" * 1000 + model_file.write_bytes(test_data) + + # Verify file integrity + read_data = model_file.read_bytes() + self.assertEqual(read_data, test_data) + + # Test state files + state_file = state_dir / "runtime" / "session.db" + state_file.parent.mkdir(parents=True) + state_file.write_text("SQLite format 3\x00test database") + + # Verify state file + state_content = state_file.read_text() + self.assertIn("SQLite format 3", state_content) + + def test_backwards_compatibility_scenario(self): + """Test complete backwards compatibility scenario.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + # Scenario: User has existing legacy installation + legacy_dir = self.create_realistic_legacy_structure(base_dir) + + # User sets legacy environment variable + os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Should continue using legacy paths + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + state_dir = get_llama_stack_state_dir() + + self.assertEqual(config_dir, legacy_dir) + self.assertEqual(data_dir, legacy_dir) + self.assertEqual(state_dir, legacy_dir) + + # Should be able to access existing files + yaml_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml" + self.assertTrue(yaml_file.exists()) + + # Should be able to parse existing configuration + with open(yaml_file) as f: + config = yaml.safe_load(f) + + self.assertEqual(config["version"], 2) + + def test_error_recovery_scenarios(self): + """Test error recovery in various scenarios.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Scenario 1: Partial migration failure + self.create_realistic_legacy_structure(base_dir) + + # Create conflicting file in target location + config_dir = base_dir / ".config" / "llama-stack" + config_dir.mkdir(parents=True) + + conflicting_file = config_dir / "distributions" + 
conflicting_file.touch() # Create file instead of directory + + from llama_stack.cli.migrate_xdg import migrate_to_xdg + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup + + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=False) + + # Should handle conflicts gracefully + print_calls = [call[0][0] for call in mock_print.call_args_list] + conflict_mentioned = any( + "Warning" in call or "conflict" in call.lower() for call in print_calls + ) + + # Migration should complete with warnings + self.assertTrue(result or conflict_mentioned) + + def test_cross_platform_compatibility(self): + """Test cross-platform compatibility of XDG implementation.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + # Test with different path separators and formats + if os.name == "nt": # Windows + # Test Windows-style paths + os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config").replace("/", "\\") + os.environ["XDG_DATA_HOME"] = str(base_dir / "data").replace("/", "\\") + else: # Unix-like + # Test Unix-style paths + os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config") + os.environ["XDG_DATA_HOME"] = str(base_dir / "data") + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Should work regardless of platform + config_dir = get_llama_stack_config_dir() + data_dir = get_llama_stack_data_dir() + + # Paths should be valid for the current platform + self.assertTrue(config_dir.is_absolute()) + self.assertTrue(data_dir.is_absolute()) + + # Should be able to create directories + config_dir.mkdir(parents=True, exist_ok=True) + data_dir.mkdir(parents=True, exist_ok=True) + + self.assertTrue(config_dir.exists()) + self.assertTrue(data_dir.exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_xdg_migration.py b/tests/integration/test_xdg_migration.py new file mode 100644 index 000000000..5da17f542 --- /dev/null +++ b/tests/integration/test_xdg_migration.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
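+
+"""Integration tests covering migration of a legacy ~/.llama tree to XDG-compliant directories."""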
+ +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from llama_stack.cli.migrate_xdg import migrate_to_xdg + + +class TestXDGMigrationIntegration(unittest.TestCase): + """Integration tests for XDG migration functionality.""" + + def setUp(self): + """Set up test environment.""" + # Store original environment variables + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + # Clear environment variables + for key in self.original_env: + os.environ.pop(key, None) + + def tearDown(self): + """Clean up test environment.""" + # Restore original environment variables + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def create_legacy_structure(self, base_dir: Path) -> Path: + """Create a realistic legacy ~/.llama directory structure.""" + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + # Create distributions + distributions_dir = legacy_dir / "distributions" + distributions_dir.mkdir() + + # Create sample distribution + ollama_dir = distributions_dir / "ollama" + ollama_dir.mkdir() + (ollama_dir / "ollama-run.yaml").write_text("version: 2\napis: []\n") + (ollama_dir / "build.yaml").write_text("name: ollama\n") + + # Create providers.d + providers_dir = legacy_dir / "providers.d" + providers_dir.mkdir() + (providers_dir / "remote").mkdir() + (providers_dir / "remote" / "inference").mkdir() + (providers_dir / "remote" / "inference" / "custom.yaml").write_text("provider_type: remote::custom\n") + + # Create checkpoints + checkpoints_dir = legacy_dir / "checkpoints" + checkpoints_dir.mkdir() + model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct" + model_dir.mkdir(parents=True) + (model_dir / "consolidated.00.pth").write_text("fake model weights") + (model_dir / "params.json").write_text('{"dim": 2048}') + + # Create runtime + runtime_dir = legacy_dir / "runtime" + runtime_dir.mkdir() + (runtime_dir / "trace_store.db").write_text("fake sqlite database") + + # Create some fake files in various subdirectories + (legacy_dir / "config.json").write_text('{"version": "0.2.13"}') + + return legacy_dir + + def verify_xdg_structure(self, base_dir: Path, legacy_dir: Path): + """Verify that the XDG structure was created correctly.""" + config_dir = base_dir / ".config" / "llama-stack" + data_dir = base_dir / ".local" / "share" / "llama-stack" + state_dir = base_dir / ".local" / "state" / "llama-stack" + + # Verify distributions moved to config + self.assertTrue((config_dir / "distributions").exists()) + self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists()) + + # Verify providers.d moved to config + self.assertTrue((config_dir / "providers.d").exists()) + self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom.yaml").exists()) + + # Verify checkpoints moved to data + self.assertTrue((data_dir / "checkpoints").exists()) + self.assertTrue( + (data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists() + ) + + # Verify runtime moved to state + self.assertTrue((state_dir / "runtime").exists()) + self.assertTrue((state_dir / "runtime" / "trace_store.db").exists()) + + def test_full_migration_workflow(self): + """Test complete migration workflow from legacy to XDG.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + 
+ # Set up fake home directory + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # Verify legacy structure exists + self.assertTrue(legacy_dir.exists()) + self.assertTrue((legacy_dir / "distributions").exists()) + self.assertTrue((legacy_dir / "checkpoints").exists()) + + # Perform migration with user confirmation + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify XDG structure was created + self.verify_xdg_structure(base_dir, legacy_dir) + + # Verify legacy directory was removed + self.assertFalse(legacy_dir.exists()) + + def test_migration_dry_run(self): + """Test dry run migration (no actual file movement).""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # Perform dry run migration + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=True) + self.assertTrue(result) + + # Check that dry run message was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Dry run mode" in call for call in print_calls)) + + # Verify nothing was actually moved + self.assertTrue(legacy_dir.exists()) + self.assertTrue((legacy_dir / "distributions").exists()) + self.assertFalse((base_dir / ".config" / "llama-stack").exists()) + + def test_migration_user_cancellation(self): + """Test migration when user cancels.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # User cancels migration + with patch("builtins.input") as mock_input: + mock_input.return_value = "n" + + result = migrate_to_xdg(dry_run=False) + self.assertFalse(result) + + # Verify nothing was moved + self.assertTrue(legacy_dir.exists()) + self.assertTrue((legacy_dir / "distributions").exists()) + self.assertFalse((base_dir / ".config" / "llama-stack").exists()) + + def test_migration_with_existing_xdg_directories(self): + """Test migration when XDG directories already exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # Create existing XDG structure with conflicting files + config_dir = base_dir / ".config" / "llama-stack" + config_dir.mkdir(parents=True) + (config_dir / "distributions").mkdir() + (config_dir / "distributions" / "existing.yaml").write_text("existing config") + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Check that warning was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Warning: Target already 
exists" in call for call in print_calls)) + + # Verify existing file wasn't overwritten + self.assertTrue((config_dir / "distributions" / "existing.yaml").exists()) + + # Legacy distributions should still exist due to conflict + self.assertTrue((legacy_dir / "distributions").exists()) + + def test_migration_partial_success(self): + """Test migration when some items succeed and others fail.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + self.create_legacy_structure(base_dir) + + # Create readonly target directory to simulate permission error + config_dir = base_dir / ".config" / "llama-stack" + config_dir.mkdir(parents=True) + distributions_target = config_dir / "distributions" + distributions_target.mkdir() + distributions_target.chmod(0o444) # Read-only + + try: + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup + + migrate_to_xdg(dry_run=False) + # Should return True even with partial success + + # Some items should have been migrated successfully + self.assertTrue((base_dir / ".local" / "share" / "llama-stack" / "checkpoints").exists()) + + finally: + # Restore permissions for cleanup + distributions_target.chmod(0o755) + + def test_migration_empty_legacy_directory(self): + """Test migration when legacy directory exists but is empty.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create empty legacy directory + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + def test_migration_preserves_file_permissions(self): + """Test that migration preserves file permissions.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure with specific permissions + legacy_dir = self.create_legacy_structure(base_dir) + + # Set specific permissions on a file + config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml" + config_file.chmod(0o600) + original_stat = config_file.stat() + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify permissions were preserved + migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml" + self.assertTrue(migrated_file.exists()) + migrated_stat = migrated_file.stat() + self.assertEqual(original_stat.st_mode, migrated_stat.st_mode) + + def test_migration_preserves_directory_structure(self): + """Test that migration preserves complex directory structures.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create complex legacy structure + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + # Create nested structure + complex_path = legacy_dir / "checkpoints" / "org" / "model" / "variant" / "files" + complex_path.mkdir(parents=True) + (complex_path / 
"model.bin").write_text("model data") + (complex_path / "config.json").write_text("config data") + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify structure was preserved + migrated_path = ( + base_dir + / ".local" + / "share" + / "llama-stack" + / "checkpoints" + / "org" + / "model" + / "variant" + / "files" + ) + self.assertTrue(migrated_path.exists()) + self.assertTrue((migrated_path / "model.bin").exists()) + self.assertTrue((migrated_path / "config.json").exists()) + + def test_migration_with_symlinks(self): + """Test migration with symbolic links in legacy directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # Create a symlink + actual_file = legacy_dir / "actual_config.yaml" + actual_file.write_text("actual config content") + + symlink_file = legacy_dir / "distributions" / "symlinked_config.yaml" + symlink_file.symlink_to(actual_file) + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify symlink was preserved + migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked_config.yaml" + self.assertTrue(migrated_symlink.exists()) + self.assertTrue(migrated_symlink.is_symlink()) + + def test_migration_large_files(self): + """Test migration with large files (simulated).""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure + legacy_dir = self.create_legacy_structure(base_dir) + + # Create a larger file (1MB) + large_file = legacy_dir / "checkpoints" / "large_model.bin" + large_file.write_bytes(b"0" * (1024 * 1024)) + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify large file was moved correctly + migrated_file = base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "large_model.bin" + self.assertTrue(migrated_file.exists()) + self.assertEqual(migrated_file.stat().st_size, 1024 * 1024) + + def test_migration_with_unicode_filenames(self): + """Test migration with unicode filenames.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy structure with unicode filenames + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + unicode_dir = legacy_dir / "distributions" / "配置" + unicode_dir.mkdir(parents=True) + unicode_file = unicode_dir / "模型.yaml" + unicode_file.write_text("unicode content") + + # Perform migration + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Verify unicode files were migrated + migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "配置" / "模型.yaml" + self.assertTrue(migrated_file.exists()) + 
self.assertEqual(migrated_file.read_text(), "unicode content") + + +class TestXDGMigrationCLI(unittest.TestCase): + """Test the CLI interface for XDG migration.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + os.environ.pop(key, None) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def test_cli_migrate_command_exists(self): + """Test that the migrate-xdg CLI command is properly registered.""" + from llama_stack.cli.llama import LlamaCLIParser + + parser = LlamaCLIParser() + + # Parse help to see if migrate-xdg is listed + with patch("sys.argv", ["llama", "--help"]): + with patch("sys.exit"): + with patch("builtins.print") as mock_print: + try: + parser.parse_args() + except SystemExit: + pass + + # Check if migrate-xdg appears in help output + help_output = "\n".join([call[0][0] for call in mock_print.call_args_list]) + self.assertIn("migrate-xdg", help_output) + + def test_cli_migrate_dry_run(self): + """Test CLI migrate command with dry-run flag.""" + import argparse + + from llama_stack.cli.migrate_xdg import MigrateXDG + + # Create parser and add migrate command + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + MigrateXDG.create(subparsers) + + # Test dry-run flag + args = parser.parse_args(["migrate-xdg", "--dry-run"]) + self.assertTrue(args.dry_run) + + # Test without dry-run flag + args = parser.parse_args(["migrate-xdg"]) + self.assertFalse(args.dry_run) + + def test_cli_migrate_execution(self): + """Test CLI migrate command execution.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create legacy directory + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() + + import argparse + + from llama_stack.cli.migrate_xdg import MigrateXDG + + # Create parser and command + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + migrate_cmd = MigrateXDG.create(subparsers) + + # Parse arguments + args = parser.parse_args(["migrate-xdg", "--dry-run"]) + + # Execute command + with patch("builtins.print") as mock_print: + migrate_cmd._run_migrate_xdg_cmd(args) + + # Verify output was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Found legacy directory" in call for call in print_calls)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/run_xdg_tests.py b/tests/run_xdg_tests.py new file mode 100755 index 000000000..26a2ff1f3 --- /dev/null +++ b/tests/run_xdg_tests.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test runner for XDG Base Directory Specification compliance tests. + +This script runs all XDG-related tests and provides a comprehensive report +of the test results. 
+""" + +import sys +import unittest +from pathlib import Path + + +def run_test_suite(): + """Run the complete XDG test suite.""" + + # Set up test environment + test_dir = Path(__file__).parent + project_root = test_dir.parent + + # Add project root to Python path + sys.path.insert(0, str(project_root)) + + # Test modules to run + test_modules = [ + "tests.unit.test_xdg_compliance", + "tests.unit.test_config_dirs", + "tests.unit.cli.test_migrate_xdg", + "tests.unit.test_template_xdg_paths", + "tests.unit.test_xdg_edge_cases", + "tests.integration.test_xdg_migration", + "tests.integration.test_xdg_e2e", + ] + + # Discover and run tests + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + print("🔍 Discovering XDG compliance tests...") + + for module_name in test_modules: + try: + # Try to load the test module + module_suite = loader.loadTestsFromName(module_name) + suite.addTest(module_suite) + print(f" ✅ Loaded {module_name}") + except Exception as e: + print(f" ⚠️ Failed to load {module_name}: {e}") + + # Run the tests + print("\n🧪 Running XDG compliance tests...") + print("=" * 60) + + runner = unittest.TextTestRunner(verbosity=2, stream=sys.stdout, descriptions=True, buffer=True) + + result = runner.run(suite) + + # Print summary + print("\n" + "=" * 60) + print("📊 Test Summary:") + print(f" Tests run: {result.testsRun}") + print(f" Failures: {len(result.failures)}") + print(f" Errors: {len(result.errors)}") + print(f" Skipped: {len(result.skipped)}") + + if result.failures: + print("\n❌ Failures:") + for test, traceback in result.failures: + print(f" - {test}: {traceback.split('AssertionError:')[-1].strip()}") + + if result.errors: + print("\n💥 Errors:") + for test, traceback in result.errors: + print(f" - {test}: {traceback.split('Exception:')[-1].strip()}") + + if result.skipped: + print("\n⏭️ Skipped:") + for test, reason in result.skipped: + print(f" - {test}: {reason}") + + # Overall result + if result.wasSuccessful(): + print("\n🎉 All XDG compliance tests passed!") + return 0 + else: + print("\n⚠️ Some XDG compliance tests failed.") + return 1 + + +def run_quick_tests(): + """Run a quick subset of critical XDG tests.""" + + print("🚀 Running quick XDG compliance tests...") + + # Add project root to Python path + project_root = Path(__file__).parent.parent + sys.path.insert(0, str(project_root)) + + # Quick test: Basic XDG functionality + try: + from llama_stack.distribution.utils.xdg_utils import ( + get_llama_stack_config_dir, + get_xdg_compliant_path, + get_xdg_config_home, + ) + + print(" ✅ XDG utilities import successfully") + + # Test basic functionality + config_home = get_xdg_config_home() + llama_config = get_llama_stack_config_dir() + compliant_path = get_xdg_compliant_path("config", "test") + + print(f" ✅ XDG config home: {config_home}") + print(f" ✅ Llama Stack config: {llama_config}") + print(f" ✅ Compliant path: {compliant_path}") + + except Exception as e: + print(f" ❌ XDG utilities failed: {e}") + return 1 + + # Quick test: Config dirs integration + try: + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + LLAMA_STACK_CONFIG_DIR, + ) + + print(f" ✅ Config dirs integration: {LLAMA_STACK_CONFIG_DIR}") + print(f" ✅ Checkpoint directory: {DEFAULT_CHECKPOINT_DIR}") + + except Exception as e: + print(f" ❌ Config dirs integration failed: {e}") + return 1 + + # Quick test: CLI migrate command + try: + print(" ✅ CLI migrate command available") + + except Exception as e: + print(f" ❌ CLI migrate command failed: {e}") + return 1 
+
+    print("\n🎉 Quick XDG compliance tests passed!")
+    return 0
+
+
+def main():
+    """Main test runner entry point."""
+
+    if len(sys.argv) > 1 and sys.argv[1] == "--quick":
+        return run_quick_tests()
+    else:
+        return run_test_suite()
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/unit/cli/test_migrate_xdg.py b/tests/unit/cli/test_migrate_xdg.py
new file mode 100644
index 000000000..b3ebd69e0
--- /dev/null
+++ b/tests/unit/cli/test_migrate_xdg.py
@@ -0,0 +1,489 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import os
+import tempfile
+import unittest
+from io import StringIO
+from pathlib import Path
+from unittest.mock import patch
+
+from llama_stack.cli.migrate_xdg import MigrateXDG, migrate_to_xdg
+
+
+class TestMigrateXDGCLI(unittest.TestCase):
+    """Test the MigrateXDG CLI command."""
+
+    def setUp(self):
+        """Set up test environment."""
+        self.original_env = {}
+        for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
+            self.original_env[key] = os.environ.get(key)
+            os.environ.pop(key, None)
+
+    def tearDown(self):
+        """Clean up test environment."""
+        for key, value in self.original_env.items():
+            if value is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = value
+
+    def create_parser_with_migrate_cmd(self):
+        """Create parser with migrate-xdg command."""
+        parser = argparse.ArgumentParser()
+        subparsers = parser.add_subparsers(dest="command")
+        migrate_cmd = MigrateXDG.create(subparsers)
+        return parser, migrate_cmd
+
+    def test_migrate_xdg_command_creation(self):
+        """Test that MigrateXDG command can be created."""
+        parser, migrate_cmd = self.create_parser_with_migrate_cmd()
+
+        self.assertIsInstance(migrate_cmd, MigrateXDG)
+        self.assertEqual(migrate_cmd.parser.prog, "llama migrate-xdg")
+        self.assertEqual(migrate_cmd.parser.description, "Migrate from legacy ~/.llama to XDG-compliant directories")
+
+    def test_migrate_xdg_argument_parsing(self):
+        """Test argument parsing for migrate-xdg command."""
+        parser, _ = self.create_parser_with_migrate_cmd()
+
+        # Test with dry-run flag
+        args = parser.parse_args(["migrate-xdg", "--dry-run"])
+        self.assertEqual(args.command, "migrate-xdg")
+        self.assertTrue(args.dry_run)
+
+        # Test without dry-run flag
+        args = parser.parse_args(["migrate-xdg"])
+        self.assertEqual(args.command, "migrate-xdg")
+        self.assertFalse(args.dry_run)
+
+    def test_migrate_xdg_help_text(self):
+        """Test help text for migrate-xdg command."""
+        parser, _ = self.create_parser_with_migrate_cmd()
+
+        # Capture help output
+        with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+            with patch("sys.exit"):
+                try:
+                    parser.parse_args(["migrate-xdg", "--help"])
+                except SystemExit:
+                    pass
+
+        help_text = mock_stdout.getvalue()
+        self.assertIn("migrate-xdg", help_text)
+        self.assertIn("XDG-compliant directories", help_text)
+        self.assertIn("--dry-run", help_text)
+
+    def test_migrate_xdg_command_execution_no_legacy(self):
+        """Test command execution when no legacy directory exists."""
+        parser, migrate_cmd =
self.create_parser_with_migrate_cmd() + + with tempfile.TemporaryDirectory() as temp_dir: + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = Path(temp_dir) + + args = parser.parse_args(["migrate-xdg"]) + + with patch("builtins.print") as mock_print: + result = migrate_cmd._run_migrate_xdg_cmd(args) + + # Should succeed when no migration needed + self.assertEqual(result, 0) + + # Should print appropriate message + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("No legacy directory found" in call for call in print_calls)) + + def test_migrate_xdg_command_execution_with_legacy(self): + """Test command execution when legacy directory exists.""" + parser, migrate_cmd = self.create_parser_with_migrate_cmd() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + args = parser.parse_args(["migrate-xdg"]) + + with patch("builtins.print") as mock_print: + result = migrate_cmd._run_migrate_xdg_cmd(args) + + # Should succeed + self.assertEqual(result, 0) + + # Should print migration information + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Found legacy directory" in call for call in print_calls)) + + def test_migrate_xdg_command_execution_dry_run(self): + """Test command execution with dry-run flag.""" + parser, migrate_cmd = self.create_parser_with_migrate_cmd() + + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + args = parser.parse_args(["migrate-xdg", "--dry-run"]) + + with patch("builtins.print") as mock_print: + result = migrate_cmd._run_migrate_xdg_cmd(args) + + # Should succeed + self.assertEqual(result, 0) + + # Should print dry-run information + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Dry run mode" in call for call in print_calls)) + + def test_migrate_xdg_command_execution_error_handling(self): + """Test command execution with error handling.""" + parser, migrate_cmd = self.create_parser_with_migrate_cmd() + + args = parser.parse_args(["migrate-xdg"]) + + # Mock migrate_to_xdg to raise an exception + with patch("llama_stack.cli.migrate_xdg.migrate_to_xdg") as mock_migrate: + mock_migrate.side_effect = Exception("Test error") + + with patch("builtins.print") as mock_print: + result = migrate_cmd._run_migrate_xdg_cmd(args) + + # Should return error code + self.assertEqual(result, 1) + + # Should print error message + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Error during migration" in call for call in print_calls)) + + def test_migrate_xdg_command_integration(self): + """Test full integration of migrate-xdg command.""" + from llama_stack.cli.llama import LlamaCLIParser + + # Create main parser + main_parser = LlamaCLIParser() + + # Test that migrate-xdg is in the subcommands + with patch("sys.argv", ["llama", "migrate-xdg", "--help"]): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + with patch("sys.exit"): + try: + main_parser.parse_args() + except SystemExit: + pass 
+ + help_text = mock_stdout.getvalue() + self.assertIn("migrate-xdg", help_text) + + +class TestMigrateXDGFunction(unittest.TestCase): + """Test the migrate_to_xdg function directly.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + os.environ.pop(key, None) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def create_legacy_structure(self, base_dir: Path) -> Path: + """Create a test legacy directory structure.""" + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() + + # Create distributions + (legacy_dir / "distributions").mkdir() + (legacy_dir / "distributions" / "ollama").mkdir() + (legacy_dir / "distributions" / "ollama" / "run.yaml").write_text("version: 2\n") + + # Create checkpoints + (legacy_dir / "checkpoints").mkdir() + (legacy_dir / "checkpoints" / "model.bin").write_text("fake model") + + # Create providers.d + (legacy_dir / "providers.d").mkdir() + (legacy_dir / "providers.d" / "provider.yaml").write_text("provider: test\n") + + # Create runtime + (legacy_dir / "runtime").mkdir() + (legacy_dir / "runtime" / "trace.db").write_text("fake database") + + return legacy_dir + + def test_migrate_to_xdg_no_legacy_directory(self): + """Test migrate_to_xdg when no legacy directory exists.""" + with tempfile.TemporaryDirectory() as temp_dir: + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = Path(temp_dir) + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + def test_migrate_to_xdg_dry_run(self): + """Test migrate_to_xdg with dry_run=True.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=True) + self.assertTrue(result) + + # Should print dry run information + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Dry run mode" in call for call in print_calls)) + + # Legacy directory should still exist + self.assertTrue(legacy_dir.exists()) + + def test_migrate_to_xdg_user_confirms(self): + """Test migrate_to_xdg when user confirms migration.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Legacy directory should be removed + self.assertFalse(legacy_dir.exists()) + + # XDG directories should be created + self.assertTrue((base_dir / ".config" / "llama-stack").exists()) + self.assertTrue((base_dir / ".local" / "share" / "llama-stack").exists()) + + def test_migrate_to_xdg_user_cancels(self): + """Test migrate_to_xdg when user cancels migration.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + with 
patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with patch("builtins.input") as mock_input: + mock_input.return_value = "n" # Cancel migration + + result = migrate_to_xdg(dry_run=False) + self.assertFalse(result) + + # Legacy directory should still exist + self.assertTrue(legacy_dir.exists()) + + def test_migrate_to_xdg_partial_migration(self): + """Test migrate_to_xdg with partial migration (some files fail).""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + self.assertFalse(legacy_dir.exists()) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create target directory with conflicting file + config_dir = base_dir / ".config" / "llama-stack" + config_dir.mkdir(parents=True) + (config_dir / "distributions").mkdir() + (config_dir / "distributions" / "existing.yaml").write_text("existing") + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup + + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Should print warning about conflicts + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Warning: Target already exists" in call for call in print_calls)) + + def test_migrate_to_xdg_permission_error(self): + """Test migrate_to_xdg with permission errors.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + self.assertFalse(legacy_dir.exists()) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Create readonly target directory + config_dir = base_dir / ".config" / "llama-stack" + config_dir.mkdir(parents=True) + config_dir.chmod(0o444) # Read-only + + try: + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup + + with patch("builtins.print") as mock_print: + result = migrate_to_xdg(dry_run=False) + self.assertFalse(result) + + # Should handle permission errors gracefully + print_calls = [call[0][0] for call in mock_print.call_args_list] + # Should contain some error or warning message + self.assertTrue(len(print_calls) > 0) + + finally: + # Restore permissions for cleanup + config_dir.chmod(0o755) + + def test_migrate_to_xdg_empty_legacy_directory(self): + """Test migrate_to_xdg with empty legacy directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = base_dir / ".llama" + legacy_dir.mkdir() # Empty directory + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + def test_migrate_to_xdg_preserves_file_content(self): + """Test that migrate_to_xdg preserves file content correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + # Add specific content to test + test_content = "test configuration content" + (legacy_dir / "distributions" / "ollama" / "run.yaml").write_text(test_content) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with 
patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Check content was preserved + migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "run.yaml" + self.assertTrue(migrated_file.exists()) + self.assertEqual(migrated_file.read_text(), test_content) + + def test_migrate_to_xdg_with_symlinks(self): + """Test migrate_to_xdg with symbolic links.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + # Create symlink + actual_file = legacy_dir / "actual_config.yaml" + actual_file.write_text("actual config") + + symlink_file = legacy_dir / "distributions" / "symlinked.yaml" + symlink_file.symlink_to(actual_file) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Check symlink was preserved + migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked.yaml" + self.assertTrue(migrated_symlink.exists()) + self.assertTrue(migrated_symlink.is_symlink()) + + def test_migrate_to_xdg_nested_directory_structure(self): + """Test migrate_to_xdg with nested directory structures.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + # Create nested structure + nested_dir = legacy_dir / "checkpoints" / "org" / "model" / "variant" + nested_dir.mkdir(parents=True) + (nested_dir / "model.bin").write_text("nested model") + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + with patch("builtins.input") as mock_input: + mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result) + + # Check nested structure was preserved + migrated_nested = ( + base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "org" / "model" / "variant" + ) + self.assertTrue(migrated_nested.exists()) + self.assertTrue((migrated_nested / "model.bin").exists()) + + def test_migrate_to_xdg_user_input_variations(self): + """Test migrate_to_xdg with various user input variations.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) + legacy_dir = self.create_legacy_structure(base_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = base_dir + + # Test various forms of "yes" + for yes_input in ["y", "Y", "yes", "Yes", "YES"]: + # Recreate legacy directory for each test + if legacy_dir.exists(): + import shutil + + shutil.rmtree(legacy_dir) + self.create_legacy_structure(base_dir) + + with patch("builtins.input") as mock_input: + mock_input.side_effect = [yes_input, "n"] # Confirm migration, don't cleanup + + result = migrate_to_xdg(dry_run=False) + self.assertTrue(result, f"Failed with input: {yes_input}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_config_dirs.py b/tests/unit/test_config_dirs.py new file mode 100644 index 000000000..0cc4469b6 --- /dev/null +++ b/tests/unit/test_config_dirs.py @@ -0,0 +1,418 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + + +# Import after we set up environment to avoid module-level imports affecting tests +class TestConfigDirs(unittest.TestCase): + """Test the config_dirs module with XDG compliance and backwards compatibility.""" + + def setUp(self): + """Set up test environment.""" + # Store original environment variables + self.original_env = {} + self.env_vars = [ + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + "XDG_STATE_HOME", + "XDG_CACHE_HOME", + "LLAMA_STACK_CONFIG_DIR", + "SQLITE_STORE_DIR", + "FILES_STORAGE_DIR", + ] + + for key in self.env_vars: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + # Restore original environment variables + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + # Clear module cache to ensure fresh imports + import sys + + modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"] + for module in modules_to_clear: + if module in sys.modules: + del sys.modules[module] + + def clear_env_vars(self): + """Clear all relevant environment variables.""" + for key in self.env_vars: + os.environ.pop(key, None) + + def test_config_dirs_xdg_defaults(self): + """Test config_dirs with XDG default paths.""" + self.clear_env_vars() + + # Mock that no legacy directory exists + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = Path("/home/testuser") + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Import after setting up mocks + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # Verify XDG-compliant paths + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack")) + self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints")) + self.assertEqual(RUNTIME_BASE_DIR, Path("/home/testuser/.local/state/llama-stack/runtime")) + self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/home/testuser/.config/llama-stack/providers.d")) + self.assertEqual(DISTRIBS_BASE_DIR, Path("/home/testuser/.config/llama-stack/distributions")) + + def test_config_dirs_custom_xdg_paths(self): + """Test config_dirs with custom XDG paths.""" + self.clear_env_vars() + + # Set custom XDG paths + os.environ["XDG_CONFIG_HOME"] = "/custom/config" + os.environ["XDG_DATA_HOME"] = "/custom/data" + os.environ["XDG_STATE_HOME"] = "/custom/state" + + # Mock that no legacy directory exists + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # Verify custom XDG paths are used + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/custom/config/llama-stack")) + self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/custom/data/llama-stack/checkpoints")) + self.assertEqual(RUNTIME_BASE_DIR, Path("/custom/state/llama-stack/runtime")) + 
self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/custom/config/llama-stack/providers.d")) + self.assertEqual(DISTRIBS_BASE_DIR, Path("/custom/config/llama-stack/distributions")) + + def test_config_dirs_legacy_environment_variable(self): + """Test config_dirs with legacy LLAMA_STACK_CONFIG_DIR.""" + self.clear_env_vars() + + # Set legacy environment variable + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/llama" + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # All paths should use the legacy base + legacy_base = Path("/legacy/llama") + self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_base) + self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_base / "checkpoints") + self.assertEqual(RUNTIME_BASE_DIR, legacy_base / "runtime") + self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_base / "providers.d") + self.assertEqual(DISTRIBS_BASE_DIR, legacy_base / "distributions") + + def test_config_dirs_legacy_directory_exists(self): + """Test config_dirs when legacy ~/.llama directory exists.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + home_dir = Path(temp_dir) + legacy_dir = home_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() # Add content + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home_dir + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # Should use legacy directory + self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_dir) + self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_dir / "checkpoints") + self.assertEqual(RUNTIME_BASE_DIR, legacy_dir / "runtime") + self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_dir / "providers.d") + self.assertEqual(DISTRIBS_BASE_DIR, legacy_dir / "distributions") + + def test_config_dirs_precedence_order(self): + """Test precedence order: LLAMA_STACK_CONFIG_DIR > legacy directory > XDG.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + home_dir = Path(temp_dir) + legacy_dir = home_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() + + # Set both legacy env var and XDG vars + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/priority/path" + os.environ["XDG_CONFIG_HOME"] = "/custom/config" + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home_dir + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR + + # Legacy env var should take precedence + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/priority/path")) + + def test_config_dirs_all_path_types(self): + """Test that all path objects are of correct type and absolute.""" + self.clear_env_vars() + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # All should be Path objects + paths = [ + LLAMA_STACK_CONFIG_DIR, + DEFAULT_CHECKPOINT_DIR, + RUNTIME_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + DISTRIBS_BASE_DIR, + ] + + for path in paths: + self.assertIsInstance(path, Path, f"Path {path} should be Path object") + self.assertTrue(path.is_absolute(), f"Path {path} should be absolute") + + def test_config_dirs_directory_relationships(self): + """Test relationships between 
different directory paths.""" + self.clear_env_vars() + + from llama_stack.distribution.utils.config_dirs import ( + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + ) + + # Test parent-child relationships + self.assertEqual(EXTERNAL_PROVIDERS_DIR.parent, LLAMA_STACK_CONFIG_DIR) + self.assertEqual(DISTRIBS_BASE_DIR.parent, LLAMA_STACK_CONFIG_DIR) + + # Test expected subdirectory names + self.assertEqual(EXTERNAL_PROVIDERS_DIR.name, "providers.d") + self.assertEqual(DISTRIBS_BASE_DIR.name, "distributions") + + def test_config_dirs_environment_isolation(self): + """Test that config_dirs is properly isolated between tests.""" + self.clear_env_vars() + + # First import with one set of environment variables + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/first/path" + + # Clear module cache + import sys + + if "llama_stack.distribution.utils.config_dirs" in sys.modules: + del sys.modules["llama_stack.distribution.utils.config_dirs"] + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_CONFIG + + # Change environment and re-import + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/second/path" + + # Clear module cache again + if "llama_stack.distribution.utils.config_dirs" in sys.modules: + del sys.modules["llama_stack.distribution.utils.config_dirs"] + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_CONFIG + + # Should get different paths + self.assertEqual(FIRST_CONFIG, Path("/first/path")) + self.assertEqual(SECOND_CONFIG, Path("/second/path")) + + def test_config_dirs_with_tilde_expansion(self): + """Test config_dirs with tilde in paths.""" + self.clear_env_vars() + + os.environ["LLAMA_STACK_CONFIG_DIR"] = "~/custom_llama" + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR + + # Should expand tilde + expected = Path.home() / "custom_llama" + self.assertEqual(LLAMA_STACK_CONFIG_DIR, expected) + + def test_config_dirs_empty_environment_variables(self): + """Test config_dirs with empty environment variables.""" + self.clear_env_vars() + + # Set empty values + os.environ["XDG_CONFIG_HOME"] = "" + os.environ["XDG_DATA_HOME"] = "" + + # Mock no legacy directory + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = Path("/home/testuser") + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + LLAMA_STACK_CONFIG_DIR, + ) + + # Should fall back to defaults + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack")) + self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints")) + + def test_config_dirs_relative_paths(self): + """Test config_dirs with relative paths in environment variables.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + + # Use relative path + os.environ["LLAMA_STACK_CONFIG_DIR"] = "relative/config" + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR + + # Should be resolved to absolute path + self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute()) + self.assertTrue(str(LLAMA_STACK_CONFIG_DIR).endswith("relative/config")) + + def test_config_dirs_with_spaces_in_paths(self): + """Test config_dirs with spaces in directory paths.""" + self.clear_env_vars() + + path_with_spaces = "/path with spaces/llama config" + 
os.environ["LLAMA_STACK_CONFIG_DIR"] = path_with_spaces + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR + + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(path_with_spaces)) + + def test_config_dirs_unicode_paths(self): + """Test config_dirs with unicode characters in paths.""" + self.clear_env_vars() + + unicode_path = "/配置/llama-stack" + os.environ["LLAMA_STACK_CONFIG_DIR"] = unicode_path + + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR + + self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(unicode_path)) + + def test_config_dirs_compatibility_import(self): + """Test that config_dirs can be imported without errors in various scenarios.""" + self.clear_env_vars() + + # Test import with no environment variables + try: + # If we get here without exception, the import succeeded + self.assertTrue(True) + except Exception as e: + self.fail(f"Import failed: {e}") + + def test_config_dirs_multiple_imports(self): + """Test that multiple imports of config_dirs return consistent results.""" + self.clear_env_vars() + + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/consistent/path" + + # First import + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_IMPORT + + # Second import (should get cached result) + from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_IMPORT + + self.assertEqual(FIRST_IMPORT, SECOND_IMPORT) + self.assertIs(FIRST_IMPORT, SECOND_IMPORT) # Should be the same object + + +class TestConfigDirsIntegration(unittest.TestCase): + """Integration tests for config_dirs with other modules.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + # Clear module cache + import sys + + modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"] + for module in modules_to_clear: + if module in sys.modules: + del sys.modules[module] + + def test_config_dirs_with_model_utils(self): + """Test that config_dirs works correctly with model_utils.""" + for key in self.original_env: + os.environ.pop(key, None) + + from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR + from llama_stack.distribution.utils.model_utils import model_local_dir + + # Test that model_local_dir uses the correct base directory + model_descriptor = "meta-llama/Llama-3.2-1B-Instruct" + expected_path = str(DEFAULT_CHECKPOINT_DIR / model_descriptor.replace(":", "-")) + actual_path = model_local_dir(model_descriptor) + + self.assertEqual(actual_path, expected_path) + + def test_config_dirs_consistency_across_modules(self): + """Test that all modules use consistent directory paths.""" + for key in self.original_env: + os.environ.pop(key, None) + + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + from llama_stack.distribution.utils.xdg_utils import ( + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, + ) + + # Paths should be consistent between modules + self.assertEqual(LLAMA_STACK_CONFIG_DIR, get_llama_stack_config_dir()) + 
self.assertEqual(DEFAULT_CHECKPOINT_DIR.parent, get_llama_stack_data_dir()) + self.assertEqual(RUNTIME_BASE_DIR.parent, get_llama_stack_state_dir()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_template_xdg_paths.py b/tests/unit/test_template_xdg_paths.py new file mode 100644 index 000000000..d27a23756 --- /dev/null +++ b/tests/unit/test_template_xdg_paths.py @@ -0,0 +1,422 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import yaml + +# Template imports will be tested through file system access + + +class TestTemplateXDGPaths(unittest.TestCase): + """Test that templates use XDG-compliant paths correctly.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + self.env_vars = [ + "XDG_CONFIG_HOME", + "XDG_DATA_HOME", + "XDG_STATE_HOME", + "XDG_CACHE_HOME", + "LLAMA_STACK_CONFIG_DIR", + "SQLITE_STORE_DIR", + "FILES_STORAGE_DIR", + ] + + for key in self.env_vars: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def clear_env_vars(self): + """Clear all relevant environment variables.""" + for key in self.env_vars: + os.environ.pop(key, None) + + def test_ollama_template_run_yaml_xdg_paths(self): + """Test that ollama template's run.yaml uses XDG environment variables.""" + template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml" + + if not template_path.exists(): + self.skipTest("Ollama template not found") + + content = template_path.read_text() + + # Check for XDG-compliant environment variable references + self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content) + self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content) + + # Check that paths use llama-stack directory + self.assertIn("llama-stack", content) + + # Check specific path patterns + self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama", content) + self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama", content) + + def test_ollama_template_run_yaml_parsing(self): + """Test that ollama template's run.yaml can be parsed correctly.""" + template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml" + + if not template_path.exists(): + self.skipTest("Ollama template not found") + + content = template_path.read_text() + + # Replace environment variables with test values for parsing + test_content = ( + content.replace("${env.XDG_STATE_HOME:-~/.local/state}", "/test/state") + .replace("${env.XDG_DATA_HOME:-~/.local/share}", "/test/data") + .replace( + "${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}", + "/test/state/llama-stack/distributions/ollama", + ) + ) + + # Should be valid YAML + try: + yaml.safe_load(test_content) + except yaml.YAMLError as e: + self.fail(f"Template YAML is invalid: {e}") + + def test_template_environment_variable_expansion(self): + """Test environment variable expansion in templates.""" + self.clear_env_vars() + + # Set XDG variables + os.environ["XDG_STATE_HOME"] = "/custom/state" + 
os.environ["XDG_DATA_HOME"] = "/custom/data" + + # Test pattern that should expand + pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test" + expected = "/custom/state/llama-stack/test" + + # Mock environment variable expansion (this would normally be done by the shell) + expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", os.environ["XDG_STATE_HOME"]) + self.assertEqual(expanded, expected) + + def test_template_fallback_values(self): + """Test that templates have correct fallback values.""" + self.clear_env_vars() + + # Test fallback pattern + pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test" + + # When environment variable is not set, should use fallback + if "XDG_STATE_HOME" not in os.environ: + # This is what the shell would do + expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state") + expected = "~/.local/state/llama-stack/test" + self.assertEqual(expanded, expected) + + def test_ollama_template_python_config_xdg(self): + """Test that ollama template's Python config uses XDG-compliant paths.""" + template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py" + + if not template_path.exists(): + self.skipTest("Ollama template Python file not found") + + content = template_path.read_text() + + # Check for XDG-compliant environment variable references + self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content) + self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content) + + # Check that paths use llama-stack directory + self.assertIn("llama-stack", content) + + def test_template_path_consistency(self): + """Test that template paths are consistent across different files.""" + ollama_yaml_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml" + ollama_py_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py" + + if not ollama_yaml_path.exists() or not ollama_py_path.exists(): + self.skipTest("Ollama template files not found") + + yaml_content = ollama_yaml_path.read_text() + py_content = ollama_py_path.read_text() + + # Both should use the same XDG environment variable patterns + xdg_patterns = ["${env.XDG_STATE_HOME:-~/.local/state}", "${env.XDG_DATA_HOME:-~/.local/share}", "llama-stack"] + + for pattern in xdg_patterns: + self.assertIn(pattern, yaml_content, f"Pattern {pattern} not found in YAML") + self.assertIn(pattern, py_content, f"Pattern {pattern} not found in Python") + + def test_template_no_hardcoded_legacy_paths(self): + """Test that templates don't contain hardcoded legacy paths.""" + template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates" + + if not template_dir.exists(): + self.skipTest("Templates directory not found") + + # Check various template files + for template_path in template_dir.rglob("*.yaml"): + content = template_path.read_text() + + # Should not contain hardcoded ~/.llama paths + self.assertNotIn("~/.llama", content, f"Found hardcoded ~/.llama in {template_path}") + + # Should not contain hardcoded /tmp paths for persistent data + if "db_path" in content or "storage_dir" in content: + self.assertNotIn("/tmp", content, f"Found hardcoded /tmp in {template_path}") + + def test_template_environment_variable_format(self): + """Test that templates use correct environment variable format.""" + template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates" + + if not template_dir.exists(): + self.skipTest("Templates 
directory not found") + + # Pattern for XDG environment variables with fallbacks + xdg_patterns = [ + "${env.XDG_CONFIG_HOME:-~/.config}", + "${env.XDG_DATA_HOME:-~/.local/share}", + "${env.XDG_STATE_HOME:-~/.local/state}", + "${env.XDG_CACHE_HOME:-~/.cache}", + ] + + for template_path in template_dir.rglob("*.yaml"): + content = template_path.read_text() + + # If XDG variables are used, they should have proper fallbacks + for pattern in xdg_patterns: + base_var = pattern.split(":-")[0] + "}" + if base_var in content: + self.assertIn(pattern, content, f"XDG variable without fallback in {template_path}") + + def test_template_sqlite_store_dir_xdg(self): + """Test that SQLITE_STORE_DIR uses XDG-compliant fallback.""" + template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates" + + if not template_dir.exists(): + self.skipTest("Templates directory not found") + + for template_path in template_dir.rglob("*.yaml"): + content = template_path.read_text() + + if "SQLITE_STORE_DIR" in content: + # Should use XDG fallback pattern + self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content) + self.assertIn("llama-stack", content) + + def test_template_files_storage_dir_xdg(self): + """Test that FILES_STORAGE_DIR uses XDG-compliant fallback.""" + template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates" + + if not template_dir.exists(): + self.skipTest("Templates directory not found") + + for template_path in template_dir.rglob("*.yaml"): + content = template_path.read_text() + + if "FILES_STORAGE_DIR" in content: + # Should use XDG fallback pattern + self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content) + self.assertIn("llama-stack", content) + + +class TestTemplateCodeGeneration(unittest.TestCase): + """Test template code generation with XDG paths.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def test_provider_codegen_xdg_paths(self): + """Test that provider code generation uses XDG-compliant paths.""" + codegen_path = Path(__file__).parent.parent.parent / "scripts" / "provider_codegen.py" + + if not codegen_path.exists(): + self.skipTest("Provider codegen script not found") + + content = codegen_path.read_text() + + # Should use XDG-compliant path in documentation + self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack", content) + + # Should not use hardcoded ~/.llama paths + self.assertNotIn("~/.llama/dummy", content) + + def test_template_sample_config_paths(self): + """Test that template sample configs use XDG-compliant paths.""" + # This test checks that when templates generate sample configs, + # they use XDG-compliant paths + + # Mock a template that generates sample config + with patch("llama_stack.templates.template.Template") as mock_template: + mock_instance = MagicMock() + mock_template.return_value = mock_instance + + # Mock sample config generation + def mock_sample_config(distro_dir): + # Should use XDG-compliant path structure + self.assertIn("llama-stack", distro_dir) + return {"config": "test"} + + mock_instance.sample_run_config = mock_sample_config + + # Test sample config generation + template = mock_template() + 
template.sample_run_config("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/test") + + def test_template_path_substitution(self): + """Test that template path substitution works correctly.""" + # Test path substitution in template generation + + original_path = "~/.llama/distributions/test" + + # Should be converted to XDG-compliant path + xdg_path = original_path.replace("~/.llama", "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack") + expected = "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/test" + + self.assertEqual(xdg_path, expected) + + def test_template_environment_variable_precedence(self): + """Test environment variable precedence in templates.""" + # Test that custom XDG variables take precedence over defaults + + test_cases = [ + { + "env": {"XDG_STATE_HOME": "/custom/state"}, + "pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test", + "expected": "/custom/state/llama-stack/test", + }, + { + "env": {}, # No XDG variable set + "pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test", + "expected": "~/.local/state/llama-stack/test", + }, + ] + + for case in test_cases: + # Clear environment + for key in ["XDG_STATE_HOME", "XDG_DATA_HOME", "XDG_CONFIG_HOME"]: + os.environ.pop(key, None) + + # Set test environment + for key, value in case["env"].items(): + os.environ[key] = value + + # Simulate shell variable expansion + pattern = case["pattern"] + for key, value in case["env"].items(): + var_pattern = f"${{env.{key}:-" + if var_pattern in pattern: + # Replace with actual value + pattern = pattern.replace(f"${{env.{key}:-~/.local/state}}", value) + + # If no replacement happened, use fallback + if "${env.XDG_STATE_HOME:-~/.local/state}" in pattern: + pattern = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state") + + self.assertEqual(pattern, case["expected"]) + + +class TestTemplateIntegration(unittest.TestCase): + """Integration tests for templates with XDG compliance.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def test_template_with_xdg_environment(self): + """Test template behavior with XDG environment variables set.""" + # Clear environment + for key in self.original_env: + os.environ.pop(key, None) + + # Set custom XDG variables + os.environ["XDG_CONFIG_HOME"] = "/custom/config" + os.environ["XDG_DATA_HOME"] = "/custom/data" + os.environ["XDG_STATE_HOME"] = "/custom/state" + + # Test that template paths would resolve correctly + # (This is a conceptual test since actual shell expansion happens at runtime) + + template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test" + + # In a real shell, this would expand to: + + # Verify the pattern structure is correct + self.assertIn("XDG_STATE_HOME", template_pattern) + self.assertIn("llama-stack", template_pattern) + self.assertIn("~/.local/state", template_pattern) # fallback + + def test_template_with_no_xdg_environment(self): + """Test template behavior with no XDG environment variables.""" + # Clear all XDG environment variables + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"]: + os.environ.pop(key, None) + + # Test that templates would use fallback 
values + template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test" + + # In a real shell with no XDG_STATE_HOME, this would expand to: + + # Verify the pattern structure includes fallback + self.assertIn(":-~/.local/state", template_pattern) + + def test_template_consistency_across_providers(self): + """Test that all template providers use consistent XDG patterns.""" + templates_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates" + + if not templates_dir.exists(): + self.skipTest("Templates directory not found") + + # Expected XDG patterns that should be consistent across templates + + # Check a few different provider templates + provider_templates = [] + for provider_dir in templates_dir.iterdir(): + if provider_dir.is_dir() and not provider_dir.name.startswith("."): + run_yaml = provider_dir / "run.yaml" + if run_yaml.exists(): + provider_templates.append(run_yaml) + + if not provider_templates: + self.skipTest("No provider templates found") + + # Check that templates use consistent patterns + for template_path in provider_templates[:3]: # Check first 3 templates + content = template_path.read_text() + + # Should use llama-stack in paths + if any(xdg_var in content for xdg_var in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME"]): + self.assertIn("llama-stack", content, f"Template {template_path} uses XDG but not llama-stack") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_xdg_compliance.py b/tests/unit/test_xdg_compliance.py new file mode 100644 index 000000000..cd2b99cc8 --- /dev/null +++ b/tests/unit/test_xdg_compliance.py @@ -0,0 +1,610 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from llama_stack.distribution.utils.xdg_utils import ( + ensure_directory_exists, + get_llama_stack_cache_dir, + get_llama_stack_config_dir, + get_llama_stack_data_dir, + get_llama_stack_state_dir, + get_xdg_cache_home, + get_xdg_compliant_path, + get_xdg_config_home, + get_xdg_data_home, + get_xdg_state_home, + migrate_legacy_directory, +) + + +class TestXDGCompliance(unittest.TestCase): + """Comprehensive test suite for XDG Base Directory Specification compliance.""" + + def setUp(self): + """Set up test environment.""" + # Store original environment variables + self.original_env = {} + self.xdg_vars = ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME"] + self.llama_vars = ["LLAMA_STACK_CONFIG_DIR", "SQLITE_STORE_DIR", "FILES_STORAGE_DIR"] + + for key in self.xdg_vars + self.llama_vars: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + # Restore original environment variables + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def clear_env_vars(self, vars_to_clear=None): + """Helper to clear environment variables.""" + if vars_to_clear is None: + vars_to_clear = self.xdg_vars + self.llama_vars + + for key in vars_to_clear: + os.environ.pop(key, None) + + def test_xdg_defaults(self): + """Test that XDG directories use correct defaults when no env vars are set.""" + self.clear_env_vars() + home = Path.home() + + self.assertEqual(get_xdg_config_home(), home / ".config") + self.assertEqual(get_xdg_data_home(), home / ".local" / "share") + self.assertEqual(get_xdg_cache_home(), home / ".cache") + self.assertEqual(get_xdg_state_home(), home / ".local" / "state") + + def test_xdg_custom_paths(self): + """Test that custom XDG paths are respected.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config") + os.environ["XDG_DATA_HOME"] = str(temp_path / "data") + os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache") + os.environ["XDG_STATE_HOME"] = str(temp_path / "state") + + self.assertEqual(get_xdg_config_home(), temp_path / "config") + self.assertEqual(get_xdg_data_home(), temp_path / "data") + self.assertEqual(get_xdg_cache_home(), temp_path / "cache") + self.assertEqual(get_xdg_state_home(), temp_path / "state") + + def test_xdg_paths_with_tilde(self): + """Test XDG paths that use tilde expansion.""" + os.environ["XDG_CONFIG_HOME"] = "~/custom_config" + os.environ["XDG_DATA_HOME"] = "~/custom_data" + + home = Path.home() + self.assertEqual(get_xdg_config_home(), home / "custom_config") + self.assertEqual(get_xdg_data_home(), home / "custom_data") + + def test_xdg_paths_relative(self): + """Test XDG paths with relative paths get resolved.""" + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + os.environ["XDG_CONFIG_HOME"] = "relative_config" + + # Should resolve relative to current directory + result = get_xdg_config_home() + self.assertTrue(result.is_absolute()) + self.assertTrue(str(result).endswith("relative_config")) + + def test_llama_stack_directories_new_installation(self): + """Test llama-stack directories for new installations (no legacy directory).""" + self.clear_env_vars() + home = Path.home() + + # Mock that ~/.llama doesn't exist + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as 
mock_home: + mock_home.return_value = home + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack") + self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack") + self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack") + self.assertEqual(get_llama_stack_cache_dir(), home / ".cache" / "llama-stack") + + def test_llama_stack_directories_with_custom_xdg(self): + """Test llama-stack directories with custom XDG paths.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config") + os.environ["XDG_DATA_HOME"] = str(temp_path / "data") + os.environ["XDG_STATE_HOME"] = str(temp_path / "state") + os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache") + + # Mock that ~/.llama doesn't exist + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack") + self.assertEqual(get_llama_stack_data_dir(), temp_path / "data" / "llama-stack") + self.assertEqual(get_llama_stack_state_dir(), temp_path / "state" / "llama-stack") + self.assertEqual(get_llama_stack_cache_dir(), temp_path / "cache" / "llama-stack") + + def test_legacy_environment_variable_precedence(self): + """Test that LLAMA_STACK_CONFIG_DIR takes highest precedence.""" + with tempfile.TemporaryDirectory() as temp_dir: + legacy_path = Path(temp_dir) / "legacy" + xdg_path = Path(temp_dir) / "xdg" + + # Set both legacy and XDG variables + os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_path) + os.environ["XDG_CONFIG_HOME"] = str(xdg_path / "config") + os.environ["XDG_DATA_HOME"] = str(xdg_path / "data") + os.environ["XDG_STATE_HOME"] = str(xdg_path / "state") + + # Legacy should take precedence for all directory types + self.assertEqual(get_llama_stack_config_dir(), legacy_path) + self.assertEqual(get_llama_stack_data_dir(), legacy_path) + self.assertEqual(get_llama_stack_state_dir(), legacy_path) + self.assertEqual(get_llama_stack_cache_dir(), legacy_path) + + def test_legacy_directory_exists_and_has_content(self): + """Test that existing ~/.llama directory with content is used.""" + with tempfile.TemporaryDirectory() as temp_dir: + home = Path(temp_dir) + legacy_llama = home / ".llama" + legacy_llama.mkdir() + + # Create some content to simulate existing data + (legacy_llama / "test_file").touch() + (legacy_llama / "distributions").mkdir() + + # Clear environment variables + self.clear_env_vars() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + + self.assertEqual(get_llama_stack_config_dir(), legacy_llama) + self.assertEqual(get_llama_stack_data_dir(), legacy_llama) + self.assertEqual(get_llama_stack_state_dir(), legacy_llama) + + def test_legacy_directory_exists_but_empty(self): + """Test that empty ~/.llama directory is ignored in favor of XDG.""" + with tempfile.TemporaryDirectory() as temp_dir: + home = Path(temp_dir) + legacy_llama = home / ".llama" + legacy_llama.mkdir() + # Don't add any content - directory is empty + + self.clear_env_vars() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + + # Should use XDG paths since legacy directory is empty + 
self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack") + self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack") + self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack") + + def test_xdg_compliant_path_function(self): + """Test the get_xdg_compliant_path utility function.""" + self.clear_env_vars() + home = Path.home() + + # Mock that ~/.llama doesn't exist + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + self.assertEqual(get_xdg_compliant_path("config"), home / ".config" / "llama-stack") + self.assertEqual( + get_xdg_compliant_path("data", "models"), home / ".local" / "share" / "llama-stack" / "models" + ) + self.assertEqual( + get_xdg_compliant_path("state", "runtime"), home / ".local" / "state" / "llama-stack" / "runtime" + ) + self.assertEqual(get_xdg_compliant_path("cache", "temp"), home / ".cache" / "llama-stack" / "temp") + + def test_xdg_compliant_path_invalid_type(self): + """Test that invalid path types raise ValueError.""" + with self.assertRaises(ValueError) as context: + get_xdg_compliant_path("invalid_type") + + self.assertIn("Unknown path type", str(context.exception)) + self.assertIn("invalid_type", str(context.exception)) + + def test_xdg_compliant_path_with_subdirectory(self): + """Test get_xdg_compliant_path with various subdirectories.""" + self.clear_env_vars() + home = Path.home() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Test nested subdirectories + self.assertEqual( + get_xdg_compliant_path("data", "models/checkpoints"), + home / ".local" / "share" / "llama-stack" / "models/checkpoints", + ) + + # Test with Path object + self.assertEqual( + get_xdg_compliant_path("config", str(Path("distributions") / "ollama")), + home / ".config" / "llama-stack" / "distributions" / "ollama", + ) + + def test_ensure_directory_exists(self): + """Test the ensure_directory_exists utility function.""" + with tempfile.TemporaryDirectory() as temp_dir: + test_path = Path(temp_dir) / "nested" / "directory" / "structure" + + # Directory shouldn't exist initially + self.assertFalse(test_path.exists()) + + # Create it + ensure_directory_exists(test_path) + + # Should exist now + self.assertTrue(test_path.exists()) + self.assertTrue(test_path.is_dir()) + + def test_ensure_directory_exists_already_exists(self): + """Test ensure_directory_exists when directory already exists.""" + with tempfile.TemporaryDirectory() as temp_dir: + test_path = Path(temp_dir) / "existing" + test_path.mkdir() + + # Should not raise an error + ensure_directory_exists(test_path) + self.assertTrue(test_path.exists()) + + def test_config_dirs_import_and_types(self): + """Test that the config_dirs module imports correctly and has proper types.""" + from llama_stack.distribution.utils.config_dirs import ( + DEFAULT_CHECKPOINT_DIR, + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + RUNTIME_BASE_DIR, + ) + + # All should be Path objects + self.assertIsInstance(LLAMA_STACK_CONFIG_DIR, Path) + self.assertIsInstance(DEFAULT_CHECKPOINT_DIR, Path) + self.assertIsInstance(RUNTIME_BASE_DIR, Path) + 
self.assertIsInstance(EXTERNAL_PROVIDERS_DIR, Path) + self.assertIsInstance(DISTRIBS_BASE_DIR, Path) + + # All should be absolute paths + self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute()) + self.assertTrue(DEFAULT_CHECKPOINT_DIR.is_absolute()) + self.assertTrue(RUNTIME_BASE_DIR.is_absolute()) + self.assertTrue(EXTERNAL_PROVIDERS_DIR.is_absolute()) + self.assertTrue(DISTRIBS_BASE_DIR.is_absolute()) + + def test_config_dirs_proper_structure(self): + """Test that config_dirs uses proper XDG structure.""" + from llama_stack.distribution.utils.config_dirs import ( + DISTRIBS_BASE_DIR, + EXTERNAL_PROVIDERS_DIR, + LLAMA_STACK_CONFIG_DIR, + ) + + # Check that paths contain expected components + config_str = str(LLAMA_STACK_CONFIG_DIR) + self.assertTrue( + "llama-stack" in config_str or ".llama" in config_str, + f"Config dir should contain 'llama-stack' or '.llama': {config_str}", + ) + + # Test relationships between directories + self.assertEqual(DISTRIBS_BASE_DIR, LLAMA_STACK_CONFIG_DIR / "distributions") + self.assertEqual(EXTERNAL_PROVIDERS_DIR, LLAMA_STACK_CONFIG_DIR / "providers.d") + + def test_environment_variable_combinations(self): + """Test various combinations of environment variables.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Test partial XDG variables + os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config") + # Leave others as default + self.clear_env_vars(["XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"]) + + home = Path.home() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack") + self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack") + self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack") + + def test_migrate_legacy_directory_no_legacy(self): + """Test migration when no legacy directory exists.""" + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = Path("/fake/home") + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Should return True (success) when no migration needed + result = migrate_legacy_directory() + self.assertTrue(result) + + def test_migrate_legacy_directory_exists(self): + """Test migration message when legacy directory exists.""" + with tempfile.TemporaryDirectory() as temp_dir: + home = Path(temp_dir) + legacy_llama = home / ".llama" + legacy_llama.mkdir() + (legacy_llama / "test_file").touch() + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + with patch("builtins.print") as mock_print: + result = migrate_legacy_directory() + self.assertTrue(result) + + # Check that migration information was printed + print_calls = [call[0][0] for call in mock_print.call_args_list] + self.assertTrue(any("Found legacy directory" in call for call in print_calls)) + self.assertTrue(any("Consider migrating" in call for call in print_calls)) + + def test_path_consistency_across_functions(self): + """Test that all path functions return consistent results.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + home = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home + with 
patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + # All config-related functions should return the same base + config_dir = get_llama_stack_config_dir() + config_path = get_xdg_compliant_path("config") + self.assertEqual(config_dir, config_path) + + # All data-related functions should return the same base + data_dir = get_llama_stack_data_dir() + data_path = get_xdg_compliant_path("data") + self.assertEqual(data_dir, data_path) + + def test_unicode_and_special_characters(self): + """Test XDG paths with unicode and special characters.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Test with unicode characters + unicode_path = Path(temp_dir) / "配置" / "llama-stack" + os.environ["XDG_CONFIG_HOME"] = str(unicode_path.parent) + + result = get_xdg_config_home() + self.assertEqual(result, unicode_path.parent) + + # Test spaces in paths + space_path = Path(temp_dir) / "my config" + os.environ["XDG_CONFIG_HOME"] = str(space_path) + + result = get_xdg_config_home() + self.assertEqual(result, space_path) + + def test_concurrent_access_safety(self): + """Test that XDG functions are safe for concurrent access.""" + import threading + import time + + results = [] + errors = [] + + def worker(): + try: + # Simulate concurrent access + config_dir = get_llama_stack_config_dir() + time.sleep(0.01) # Small delay to increase chance of race conditions + data_dir = get_llama_stack_data_dir() + results.append((config_dir, data_dir)) + except Exception as e: + errors.append(e) + + # Start multiple threads + threads = [] + for _ in range(10): + t = threading.Thread(target=worker) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Check results + self.assertEqual(len(errors), 0, f"Concurrent access errors: {errors}") + self.assertEqual(len(results), 10) + + # All results should be identical + first_result = results[0] + for result in results[1:]: + self.assertEqual(result, first_result) + + def test_symlink_handling(self): + """Test XDG path handling with symbolic links.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create actual directory + actual_dir = temp_path / "actual_config" + actual_dir.mkdir() + + # Create symlink + symlink_dir = temp_path / "symlinked_config" + symlink_dir.symlink_to(actual_dir) + + os.environ["XDG_CONFIG_HOME"] = str(symlink_dir) + + result = get_xdg_config_home() + self.assertEqual(result, symlink_dir) + + # Should resolve to actual path when needed + self.assertTrue(result.exists()) + + def test_readonly_directory_handling(self): + """Test behavior when XDG directories are read-only.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + readonly_dir = temp_path / "readonly" + readonly_dir.mkdir() + + # Make directory read-only + readonly_dir.chmod(0o444) + + try: + os.environ["XDG_CONFIG_HOME"] = str(readonly_dir) + + # Should still return the path even if read-only + result = get_xdg_config_home() + self.assertEqual(result, readonly_dir) + + finally: + # Restore permissions for cleanup + readonly_dir.chmod(0o755) + + def test_nonexistent_parent_directory(self): + """Test XDG paths with non-existent parent directories.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Use a path with non-existent parents + nonexistent_path = Path(temp_dir) / "does" / "not" / "exist" / "config" + + os.environ["XDG_CONFIG_HOME"] = str(nonexistent_path) + + # Should return the path even if it doesn't exist + 
result = get_xdg_config_home() + self.assertEqual(result, nonexistent_path) + + def test_env_var_expansion(self): + """Test environment variable expansion in XDG paths.""" + with tempfile.TemporaryDirectory() as temp_dir: + os.environ["TEST_BASE"] = temp_dir + os.environ["XDG_CONFIG_HOME"] = "$TEST_BASE/config" + + # Path expansion should work + result = get_xdg_config_home() + expected = Path(temp_dir) / "config" + self.assertEqual(result, expected) + + +class TestXDGEdgeCases(unittest.TestCase): + """Test edge cases and error conditions for XDG compliance.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def test_empty_environment_variables(self): + """Test behavior with empty environment variables.""" + # Set empty values + os.environ["XDG_CONFIG_HOME"] = "" + os.environ["XDG_DATA_HOME"] = "" + + # Should fall back to defaults + home = Path.home() + self.assertEqual(get_xdg_config_home(), home / ".config") + self.assertEqual(get_xdg_data_home(), home / ".local" / "share") + + def test_whitespace_only_environment_variables(self): + """Test behavior with whitespace-only environment variables.""" + os.environ["XDG_CONFIG_HOME"] = " " + os.environ["XDG_DATA_HOME"] = "\t\n" + + # Should handle whitespace gracefully + result_config = get_xdg_config_home() + result_data = get_xdg_data_home() + + # Results should be valid Path objects + self.assertIsInstance(result_config, Path) + self.assertIsInstance(result_data, Path) + + def test_very_long_paths(self): + """Test behavior with very long directory paths.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a very long path + long_path_parts = ["very_long_directory_name_" + str(i) for i in range(20)] + long_path = Path(temp_dir) + for part in long_path_parts: + long_path = long_path / part + + os.environ["XDG_CONFIG_HOME"] = str(long_path) + + result = get_xdg_config_home() + self.assertEqual(result, long_path) + + def test_circular_symlinks(self): + """Test handling of circular symbolic links.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create circular symlinks + link1 = temp_path / "link1" + link2 = temp_path / "link2" + + try: + link1.symlink_to(link2) + link2.symlink_to(link1) + + os.environ["XDG_CONFIG_HOME"] = str(link1) + + # Should handle circular symlinks gracefully + result = get_xdg_config_home() + self.assertEqual(result, link1) + + except (OSError, NotImplementedError): + # Some systems don't support circular symlinks + self.skipTest("System doesn't support circular symlinks") + + def test_permission_denied_scenarios(self): + """Test scenarios where permission is denied.""" + # This test is platform-specific and may not work on all systems + try: + # Try to use a system directory that typically requires root + os.environ["XDG_CONFIG_HOME"] = "/root/.config" + + # Should still return the path even if we can't access it + result = get_xdg_config_home() + self.assertEqual(result, Path("/root/.config")) + + except Exception: + # If this fails, it's not critical for the XDG implementation + pass + + def test_network_paths(self): + """Test XDG paths with network/UNC paths (Windows-style).""" + # Test UNC 
path (though this may not work on non-Windows systems) + network_path = "//server/share/config" + os.environ["XDG_CONFIG_HOME"] = network_path + + result = get_xdg_config_home() + # Should handle network paths gracefully + self.assertIsInstance(result, Path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_xdg_edge_cases.py b/tests/unit/test_xdg_edge_cases.py new file mode 100644 index 000000000..671459e9e --- /dev/null +++ b/tests/unit/test_xdg_edge_cases.py @@ -0,0 +1,522 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import platform +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from llama_stack.distribution.utils.xdg_utils import ( + ensure_directory_exists, + get_llama_stack_config_dir, + get_xdg_config_home, + migrate_legacy_directory, +) + + +class TestXDGEdgeCases(unittest.TestCase): + """Test edge cases and error conditions for XDG compliance.""" + + def setUp(self): + """Set up test environment.""" + self.original_env = {} + for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME", "LLAMA_STACK_CONFIG_DIR"]: + self.original_env[key] = os.environ.get(key) + + def tearDown(self): + """Clean up test environment.""" + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + def clear_env_vars(self): + """Clear all XDG environment variables.""" + for key in self.original_env: + os.environ.pop(key, None) + + def test_very_long_paths(self): + """Test XDG functions with very long directory paths.""" + self.clear_env_vars() + + # Create a very long path (close to filesystem limits) + long_components = ["very_long_directory_name_" + str(i) for i in range(50)] + long_path = "/tmp/" + "/".join(long_components) + + # Test with very long XDG paths + os.environ["XDG_CONFIG_HOME"] = long_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(long_path)) + + # Should handle long paths in llama-stack functions + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + config_dir = get_llama_stack_config_dir() + self.assertEqual(config_dir, Path(long_path) / "llama-stack") + + def test_paths_with_special_characters(self): + """Test XDG functions with special characters in paths.""" + self.clear_env_vars() + + # Test various special characters + special_chars = [ + "path with spaces", + "path-with-hyphens", + "path_with_underscores", + "path.with.dots", + "path@with@symbols", + "path+with+plus", + "path&with&ersand", + "path(with)parentheses", + ] + + for special_path in special_chars: + with self.subTest(path=special_path): + test_path = f"/tmp/{special_path}" + os.environ["XDG_CONFIG_HOME"] = test_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(test_path)) + + def test_unicode_paths(self): + """Test XDG functions with unicode characters in paths.""" + self.clear_env_vars() + + unicode_paths = [ + "/配置/llama-stack", # Chinese + "/конфигурация/llama-stack", # Russian + "/構成/llama-stack", # Japanese + "/구성/llama-stack", # Korean + "/تكوين/llama-stack", # Arabic + "/configuración/llama-stack", # Spanish with accents + "/配置📁/llama-stack", # With emoji + ] + + for unicode_path in unicode_paths: + with self.subTest(path=unicode_path): + 
os.environ["XDG_CONFIG_HOME"] = unicode_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(unicode_path)) + + def test_network_paths(self): + """Test XDG functions with network/UNC paths.""" + self.clear_env_vars() + + if platform.system() == "Windows": + # Test Windows UNC paths + unc_paths = [ + "\\\\server\\share\\config", + "\\\\server.domain.com\\share\\config", + "\\\\192.168.1.100\\config", + ] + + for unc_path in unc_paths: + with self.subTest(path=unc_path): + os.environ["XDG_CONFIG_HOME"] = unc_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(unc_path)) + else: + # Test network mount paths on Unix-like systems + network_paths = [ + "/mnt/nfs/config", + "/net/server/config", + "/media/network/config", + ] + + for network_path in network_paths: + with self.subTest(path=network_path): + os.environ["XDG_CONFIG_HOME"] = network_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(network_path)) + + def test_nonexistent_paths(self): + """Test XDG functions with non-existent paths.""" + self.clear_env_vars() + + nonexistent_path = "/this/path/does/not/exist/config" + os.environ["XDG_CONFIG_HOME"] = nonexistent_path + + # Should return the path even if it doesn't exist + result = get_xdg_config_home() + self.assertEqual(result, Path(nonexistent_path)) + + # Should work with llama-stack functions too + with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists: + mock_exists.return_value = False + + config_dir = get_llama_stack_config_dir() + self.assertEqual(config_dir, Path(nonexistent_path) / "llama-stack") + + def test_circular_symlinks(self): + """Test XDG functions with circular symbolic links.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create circular symlinks + link1 = temp_path / "link1" + link2 = temp_path / "link2" + + try: + link1.symlink_to(link2) + link2.symlink_to(link1) + + os.environ["XDG_CONFIG_HOME"] = str(link1) + + # Should handle circular symlinks gracefully + result = get_xdg_config_home() + self.assertEqual(result, link1) + + except (OSError, NotImplementedError): + # Some systems don't support circular symlinks + self.skipTest("System doesn't support circular symlinks") + + def test_broken_symlinks(self): + """Test XDG functions with broken symbolic links.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create broken symlink + target = temp_path / "nonexistent_target" + link = temp_path / "broken_link" + + try: + link.symlink_to(target) + + os.environ["XDG_CONFIG_HOME"] = str(link) + + # Should handle broken symlinks gracefully + result = get_xdg_config_home() + self.assertEqual(result, link) + + except (OSError, NotImplementedError): + # Some systems might not support this + self.skipTest("System doesn't support broken symlinks") + + def test_readonly_directories(self): + """Test XDG functions with read-only directories.""" + self.clear_env_vars() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + readonly_dir = temp_path / "readonly" + readonly_dir.mkdir() + + # Make directory read-only + readonly_dir.chmod(0o444) + + try: + os.environ["XDG_CONFIG_HOME"] = str(readonly_dir) + + # Should still return the path + result = get_xdg_config_home() + self.assertEqual(result, readonly_dir) + + finally: + # Restore permissions for cleanup + readonly_dir.chmod(0o755) + + def test_permission_denied_access(self): + """Test 
XDG functions when permission is denied.""" + self.clear_env_vars() + + # This test is platform-specific + if platform.system() != "Windows": + # Try to use a system directory that typically requires root + restricted_paths = [ + "/root/.config", + "/etc/config", + "/var/root/config", + ] + + for restricted_path in restricted_paths: + with self.subTest(path=restricted_path): + os.environ["XDG_CONFIG_HOME"] = restricted_path + + # Should still return the path even if we can't access it + result = get_xdg_config_home() + self.assertEqual(result, Path(restricted_path)) + + def test_environment_variable_injection(self): + """Test XDG functions with environment variable injection attempts.""" + self.clear_env_vars() + + # Test potential injection attempts + injection_attempts = [ + "/tmp/config; rm -rf /", + "/tmp/config && echo 'injected'", + "/tmp/config | cat /etc/passwd", + "/tmp/config`whoami`", + "/tmp/config$(whoami)", + "/tmp/config\necho 'newline'", + ] + + for injection_attempt in injection_attempts: + with self.subTest(attempt=injection_attempt): + os.environ["XDG_CONFIG_HOME"] = injection_attempt + + # Should treat as literal path, not execute + result = get_xdg_config_home() + self.assertEqual(result, Path(injection_attempt)) + + def test_extremely_nested_paths(self): + """Test XDG functions with extremely nested directory structures.""" + self.clear_env_vars() + + # Create deeply nested path + nested_parts = ["level" + str(i) for i in range(100)] + nested_path = "/tmp/" + "/".join(nested_parts) + + os.environ["XDG_CONFIG_HOME"] = nested_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(nested_path)) + + def test_empty_and_whitespace_paths(self): + """Test XDG functions with empty and whitespace-only paths.""" + self.clear_env_vars() + + empty_values = [ + "", + " ", + "\t", + "\n", + "\r\n", + " \t \n ", + ] + + for empty_value in empty_values: + with self.subTest(value=repr(empty_value)): + os.environ["XDG_CONFIG_HOME"] = empty_value + + # Should fall back to default + result = get_xdg_config_home() + self.assertEqual(result, Path.home() / ".config") + + def test_path_with_null_bytes(self): + """Test XDG functions with null bytes in paths.""" + self.clear_env_vars() + + # Test path with null byte + null_path = "/tmp/config\x00/test" + os.environ["XDG_CONFIG_HOME"] = null_path + + # Should handle null bytes (Path will likely raise an error, which is expected) + try: + result = get_xdg_config_home() + # If it doesn't raise an error, check the result + self.assertIsInstance(result, Path) + except (ValueError, OSError): + # This is expected behavior for null bytes + pass + + def test_concurrent_access_safety(self): + """Test that XDG functions are thread-safe.""" + import threading + import time + + self.clear_env_vars() + + results = [] + errors = [] + + def worker(thread_id): + try: + # Each thread sets a different XDG path + os.environ["XDG_CONFIG_HOME"] = f"/tmp/thread_{thread_id}" + + # Small delay to increase chance of race conditions + time.sleep(0.01) + + config_dir = get_llama_stack_config_dir() + results.append((thread_id, config_dir)) + + except Exception as e: + errors.append((thread_id, e)) + + # Start multiple threads + threads = [] + for i in range(20): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Check for errors + if errors: + self.fail(f"Thread errors: {errors}") + + # Check that we got results from all threads + self.assertEqual(len(results), 
20) + + def test_filesystem_limits(self): + """Test XDG functions approaching filesystem limits.""" + self.clear_env_vars() + + # Test with very long filename (close to 255 char limit) + long_filename = "a" * 240 + long_path = f"/tmp/{long_filename}" + + os.environ["XDG_CONFIG_HOME"] = long_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(long_path)) + + def test_case_sensitivity(self): + """Test XDG functions with case sensitivity edge cases.""" + self.clear_env_vars() + + # Test case variations + case_variations = [ + "/tmp/Config", + "/tmp/CONFIG", + "/tmp/config", + "/tmp/Config/MixedCase", + ] + + for case_path in case_variations: + with self.subTest(path=case_path): + os.environ["XDG_CONFIG_HOME"] = case_path + + result = get_xdg_config_home() + self.assertEqual(result, Path(case_path)) + + def test_ensure_directory_exists_edge_cases(self): + """Test ensure_directory_exists with edge cases.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Test with file that exists but is not a directory + file_path = temp_path / "file_not_dir" + file_path.touch() + + with self.assertRaises(FileExistsError): + ensure_directory_exists(file_path) + + # Test with permission denied + if platform.system() != "Windows": + readonly_parent = temp_path / "readonly_parent" + readonly_parent.mkdir() + readonly_parent.chmod(0o444) + + try: + nested_path = readonly_parent / "nested" + + with self.assertRaises(PermissionError): + ensure_directory_exists(nested_path) + + finally: + # Restore permissions for cleanup + readonly_parent.chmod(0o755) + + def test_migrate_legacy_directory_edge_cases(self): + """Test migrate_legacy_directory with edge cases.""" + with tempfile.TemporaryDirectory() as temp_dir: + home_dir = Path(temp_dir) + + with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = home_dir + + # Test with legacy directory but no write permissions + legacy_dir = home_dir / ".llama" + legacy_dir.mkdir() + (legacy_dir / "test_file").touch() + + # Make home directory read-only + home_dir.chmod(0o444) + + try: + # Should handle permission errors gracefully + with patch("builtins.print") as mock_print: + migrate_legacy_directory() + + # Should print some information + self.assertTrue(mock_print.called) + + finally: + # Restore permissions for cleanup + home_dir.chmod(0o755) + legacy_dir.chmod(0o755) + + def test_path_traversal_attempts(self): + """Test XDG functions with path traversal attempts.""" + self.clear_env_vars() + + traversal_attempts = [ + "/tmp/config/../../../etc/passwd", + "/tmp/config/../../root/.ssh", + "/tmp/config/../../../../../etc/shadow", + "/tmp/config/./../../root", + ] + + for traversal_attempt in traversal_attempts: + with self.subTest(attempt=traversal_attempt): + os.environ["XDG_CONFIG_HOME"] = traversal_attempt + + # Should handle path traversal attempts by treating as literal paths + result = get_xdg_config_home() + self.assertEqual(result, Path(traversal_attempt)) + + def test_environment_variable_precedence_edge_cases(self): + """Test environment variable precedence with edge cases.""" + self.clear_env_vars() + + # Test with both old and new env vars set + os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/path" + os.environ["XDG_CONFIG_HOME"] = "/xdg/path" + + # Create fake legacy directory + with tempfile.TemporaryDirectory() as temp_dir: + fake_home = Path(temp_dir) + fake_legacy = fake_home / ".llama" + fake_legacy.mkdir() + (fake_legacy / "test_file").touch() + + with 
patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home: + mock_home.return_value = fake_home + + # LLAMA_STACK_CONFIG_DIR should take precedence + config_dir = get_llama_stack_config_dir() + self.assertEqual(config_dir, Path("/legacy/path")) + + def test_malformed_environment_variables(self): + """Test XDG functions with malformed environment variables.""" + self.clear_env_vars() + + malformed_values = [ + "not_an_absolute_path", + "~/tilde_not_expanded", + "$HOME/variable_not_expanded", + "relative/path/config", + "./relative/path", + "../parent/path", + ] + + for malformed_value in malformed_values: + with self.subTest(value=malformed_value): + os.environ["XDG_CONFIG_HOME"] = malformed_value + + # Should handle malformed values gracefully + result = get_xdg_config_home() + self.assertIsInstance(result, Path) + + +if __name__ == "__main__": + unittest.main()