mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 21:57:45 +00:00
Merge 407c3e3bad
into 632cf9eb72
This commit is contained in:
commit
172d578b20
50 changed files with 5611 additions and 508 deletions
172
docs/source/distributions/self_hosted_distro/ollama.md
Normal file
172
docs/source/distributions/self_hosted_distro/ollama.md
Normal file
|
@ -0,0 +1,172 @@
|
|||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# Ollama Distribution
|
||||
|
||||
The `llamastack/distribution-ollama` distribution consists of the following provider configurations.
|
||||
|
||||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| files | `inline::localfs` |
|
||||
| inference | `remote::ollama` |
|
||||
| post_training | `inline::huggingface` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
|
||||
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
||||
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
|
||||
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Ollama Server
|
||||
|
||||
This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps:
|
||||
|
||||
1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/)
|
||||
|
||||
2. **Start the Ollama server**:
|
||||
```bash
|
||||
ollama serve
|
||||
```
|
||||
By default, Ollama serves on `http://127.0.0.1:11434`
|
||||
|
||||
3. **Pull the required models**:
|
||||
```bash
|
||||
# Pull the inference model
|
||||
ollama pull meta-llama/Llama-3.2-3B-Instruct
|
||||
|
||||
# Pull the embedding model
|
||||
ollama pull all-minilm:latest
|
||||
|
||||
# (Optional) Pull the safety model for run-with-safety.yaml
|
||||
ollama pull meta-llama/Llama-Guard-3-1B
|
||||
```
|
||||
|
||||
## Supported Services
|
||||
|
||||
### Inference: Ollama
|
||||
Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable.
|
||||
|
||||
### Vector IO: FAISS
|
||||
Provides vector storage capabilities using FAISS for embeddings and similarity search operations.
|
||||
|
||||
### Safety: Llama Guard (Optional)
|
||||
When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server.
|
||||
|
||||
### Agents: Meta Reference
|
||||
Provides agent execution capabilities using the meta-reference implementation.
|
||||
|
||||
### Post-Training: Hugging Face
|
||||
Supports model fine-tuning using Hugging Face integration.
|
||||
|
||||
### Tool Runtime
|
||||
Supports various external tools including:
|
||||
- Brave Search
|
||||
- Tavily Search
|
||||
- RAG Runtime
|
||||
- Model Context Protocol
|
||||
- Wolfram Alpha
|
||||
|
||||
## Running Llama Stack with Ollama
|
||||
|
||||
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-ollama \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
llama stack build --template ollama --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
llama stack build --template ollama --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Running with Safety
|
||||
|
||||
To enable safety checks, use the `run-with-safety.yaml` configuration:
|
||||
|
||||
```bash
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client:
|
||||
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
client = LlamaStackClient(base_url="http://localhost:8321")
|
||||
|
||||
# Run inference
|
||||
response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.2-3B-Instruct",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
print(response.completion_message.content)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL.
|
||||
|
||||
2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull <model-name>`.
|
||||
|
||||
3. **Performance issues**: Consider using more powerful models or adjusting the Ollama server configuration for better performance.
|
||||
|
||||
### Logs
|
||||
|
||||
Check the Ollama server logs for any issues:
|
||||
```bash
|
||||
# Ollama logs are typically available in:
|
||||
# - macOS: ~/Library/Logs/Ollama/
|
||||
# - Linux: ~/.ollama/logs/
|
||||
```
|
191
docs/source/getting_started/xdg_compliance.md
Normal file
191
docs/source/getting_started/xdg_compliance.md
Normal file
|
@ -0,0 +1,191 @@
|
|||
# XDG Base Directory Specification Compliance
|
||||
|
||||
Starting with version 0.2.14, Llama Stack follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) for organizing configuration and data files. This provides better integration with modern desktop environments and allows for more flexible customization of where files are stored.
|
||||
|
||||
## Overview
|
||||
|
||||
The XDG Base Directory Specification defines standard locations for different types of application data:
|
||||
|
||||
- **Configuration files** (`XDG_CONFIG_HOME`): User-specific configuration files
|
||||
- **Data files** (`XDG_DATA_HOME`): User-specific data files that should persist
|
||||
- **Cache files** (`XDG_CACHE_HOME`): User-specific cache files
|
||||
- **State files** (`XDG_STATE_HOME`): User-specific state files
|
||||
|
||||
## Directory Mapping
|
||||
|
||||
Llama Stack now uses the following XDG-compliant directory structure:
|
||||
|
||||
| Data Type | XDG Directory | Default Location | Description |
|
||||
|-----------|---------------|------------------|-------------|
|
||||
| Configuration | `XDG_CONFIG_HOME` | `~/.config/llama-stack` | Distribution configs, provider configs |
|
||||
| Data | `XDG_DATA_HOME` | `~/.local/share/llama-stack` | Model checkpoints, persistent files |
|
||||
| Cache | `XDG_CACHE_HOME` | `~/.cache/llama-stack` | Temporary cache files |
|
||||
| State | `XDG_STATE_HOME` | `~/.local/state/llama-stack` | Runtime state, databases |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
You can customize the locations by setting these environment variables:
|
||||
|
||||
```bash
|
||||
# Override the base directories
|
||||
export XDG_CONFIG_HOME="/custom/config/path"
|
||||
export XDG_DATA_HOME="/custom/data/path"
|
||||
export XDG_CACHE_HOME="/custom/cache/path"
|
||||
export XDG_STATE_HOME="/custom/state/path"
|
||||
|
||||
# Or override specific Llama Stack directories
|
||||
export SQLITE_STORE_DIR="/custom/database/path"
|
||||
export FILES_STORAGE_DIR="/custom/files/path"
|
||||
```
|
||||
|
||||
## Backwards Compatibility
|
||||
|
||||
Llama Stack maintains full backwards compatibility with existing installations:
|
||||
|
||||
1. **Legacy Environment Variable**: If `LLAMA_STACK_CONFIG_DIR` is set, it will be used for all directories
|
||||
2. **Legacy Directory Detection**: If `~/.llama` exists and contains data, it will continue to be used
|
||||
3. **Gradual Migration**: New installations use XDG paths, existing installations continue to work
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Automatic Migration
|
||||
|
||||
Use the built-in migration command to move from legacy `~/.llama` to XDG-compliant directories:
|
||||
|
||||
```bash
|
||||
# Preview what would be migrated
|
||||
llama migrate-xdg --dry-run
|
||||
|
||||
# Perform the migration
|
||||
llama migrate-xdg
|
||||
```
|
||||
|
||||
### Manual Migration
|
||||
|
||||
If you prefer to migrate manually, here's the mapping:
|
||||
|
||||
```bash
|
||||
# Create XDG directories
|
||||
mkdir -p ~/.config/llama-stack
|
||||
mkdir -p ~/.local/share/llama-stack
|
||||
mkdir -p ~/.local/state/llama-stack
|
||||
|
||||
# Move configuration files
|
||||
mv ~/.llama/distributions ~/.config/llama-stack/
|
||||
mv ~/.llama/providers.d ~/.config/llama-stack/
|
||||
|
||||
# Move data files
|
||||
mv ~/.llama/checkpoints ~/.local/share/llama-stack/
|
||||
|
||||
# Move state files
|
||||
mv ~/.llama/runtime ~/.local/state/llama-stack/
|
||||
|
||||
# Clean up empty legacy directory
|
||||
rmdir ~/.llama
|
||||
```
|
||||
|
||||
### Environment Variables in Configurations
|
||||
|
||||
Template configurations now use XDG-compliant environment variables:
|
||||
|
||||
```yaml
|
||||
# Old format
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
|
||||
|
||||
# New format
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/registry.db
|
||||
```
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Using Custom XDG Directories
|
||||
|
||||
```bash
|
||||
# Set custom XDG directories
|
||||
export XDG_CONFIG_HOME="/opt/llama-stack/config"
|
||||
export XDG_DATA_HOME="/opt/llama-stack/data"
|
||||
export XDG_STATE_HOME="/opt/llama-stack/state"
|
||||
|
||||
# Start Llama Stack
|
||||
llama stack run my-distribution.yaml
|
||||
```
|
||||
|
||||
### Using Legacy Directory
|
||||
|
||||
```bash
|
||||
# Continue using legacy directory
|
||||
export LLAMA_STACK_CONFIG_DIR="/home/user/.llama"
|
||||
|
||||
# Start Llama Stack
|
||||
llama stack run my-distribution.yaml
|
||||
```
|
||||
|
||||
### Custom Database and File Locations
|
||||
|
||||
```bash
|
||||
# Override specific directories
|
||||
export SQLITE_STORE_DIR="/fast/ssd/llama-stack/databases"
|
||||
export FILES_STORAGE_DIR="/large/disk/llama-stack/files"
|
||||
|
||||
# Start Llama Stack
|
||||
llama stack run my-distribution.yaml
|
||||
```
|
||||
|
||||
## Benefits of XDG Compliance
|
||||
|
||||
1. **Standards Compliance**: Follows established Linux/Unix conventions
|
||||
2. **Better Organization**: Separates configuration, data, cache, and state files
|
||||
3. **Flexibility**: Easy to customize storage locations
|
||||
4. **Backup-Friendly**: Easier to backup just data files or just configuration
|
||||
5. **Multi-User Support**: Better support for shared systems
|
||||
6. **Cross-Platform**: Works consistently across different environments
|
||||
|
||||
## Template Updates
|
||||
|
||||
All distribution templates have been updated to use XDG-compliant paths:
|
||||
|
||||
- Database files use `XDG_STATE_HOME`
|
||||
- Model checkpoints use `XDG_DATA_HOME`
|
||||
- Configuration files use `XDG_CONFIG_HOME`
|
||||
- Cache files use `XDG_CACHE_HOME`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Migration Issues
|
||||
|
||||
If you encounter issues during migration:
|
||||
|
||||
1. **Check Permissions**: Ensure you have write permissions to target directories
|
||||
2. **Disk Space**: Verify sufficient disk space in target locations
|
||||
3. **Existing Files**: Handle conflicts with existing files in target locations
|
||||
|
||||
### Environment Variable Conflicts
|
||||
|
||||
If you have multiple environment variables set:
|
||||
|
||||
1. `LLAMA_STACK_CONFIG_DIR` takes highest precedence
|
||||
2. Individual `XDG_*` variables override defaults
|
||||
3. Fallback to legacy `~/.llama` if it exists
|
||||
4. Default to XDG standard paths for new installations
|
||||
|
||||
### Debugging Path Resolution
|
||||
|
||||
To see which paths Llama Stack is using:
|
||||
|
||||
```python
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
)
|
||||
|
||||
print(f"Config: {get_llama_stack_config_dir()}")
|
||||
print(f"Data: {get_llama_stack_data_dir()}")
|
||||
print(f"State: {get_llama_stack_state_dir()}")
|
||||
```
|
||||
|
||||
## Future Considerations
|
||||
|
||||
- Container deployments will continue to use `/app` or similar paths
|
||||
- Cloud deployments may use provider-specific storage systems
|
||||
- The XDG specification primarily applies to local development and single-user systems
|
|
@ -16,10 +16,10 @@ Meta's reference implementation of an agent system that can use tools, access ve
|
|||
```yaml
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/agents_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/responses_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/responses_store.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ Local filesystem-based dataset I/O provider for reading and writing datasets to
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/localfs_datasetio.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/localfs_datasetio.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ HuggingFace datasets provider for accessing and managing datasets from the Huggi
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/huggingface_datasetio.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/huggingface_datasetio.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ Meta's reference implementation of evaluation tasks with support for multiple la
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/meta_reference_eval.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/meta_reference_eval.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,10 +15,10 @@ Local filesystem-based file storage provider for managing files and documents lo
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/dummy/files}
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/files_metadata.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/files_metadata.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -11,14 +11,14 @@ Meta's reference implementation of telemetry and observability using OpenTelemet
|
|||
| `otel_exporter_otlp_endpoint` | `str \| None` | No | | The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable. |
|
||||
| `service_name` | `<class 'str'>` | No | | The service name to use for telemetry |
|
||||
| `sinks` | `list[inline.telemetry.meta_reference.config.TelemetrySink` | No | [<TelemetrySink.CONSOLE: 'console'>, <TelemetrySink.SQLITE: 'sqlite'>] | List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console) |
|
||||
| `sqlite_db_path` | `<class 'str'>` | No | ~/.llama/runtime/trace_store.db | The path to the SQLite database to use for storing traces |
|
||||
| `sqlite_db_path` | `<class 'str'>` | No | ${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/runtime/trace_store.db | The path to the SQLite database to use for storing traces |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/trace_store.db
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
|
||||
```
|
||||
|
|
|
@ -44,7 +44,7 @@ more details about Faiss in general.
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ Meta's reference implementation of a vector database.
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/faiss_store.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/faiss_store.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -17,10 +17,10 @@ Please refer to the remote provider documentation.
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/dummy}/milvus.db
|
||||
db_path: ${env.MILVUS_DB_PATH:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db
|
||||
path: ${env.QDRANT_PATH:=~/.llama/${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/qdrant.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -211,10 +211,10 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -16,10 +16,10 @@ Please refer to the sqlite-vec provider documentation.
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/sqlite_vec_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ uri: ${env.MILVUS_ENDPOINT}
|
|||
token: ${env.MILVUS_TOKEN}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/milvus_remote_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ user: ${env.PGVECTOR_USER}
|
|||
password: ${env.PGVECTOR_PASSWORD}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/pgvector_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
|
|||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy}/weaviate_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import argparse
|
||||
|
||||
from .download import Download
|
||||
from .migrate_xdg import MigrateXDG
|
||||
from .model import ModelParser
|
||||
from .stack import StackParser
|
||||
from .stack.utils import print_subcommand_description
|
||||
|
@ -34,6 +35,7 @@ class LlamaCLIParser:
|
|||
StackParser.create(subparsers)
|
||||
Download.create(subparsers)
|
||||
VerifyDownload.create(subparsers)
|
||||
MigrateXDG.create(subparsers)
|
||||
|
||||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
|
|
168
llama_stack/cli/migrate_xdg.py
Normal file
168
llama_stack/cli/migrate_xdg.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
)
|
||||
|
||||
from .subcommand import Subcommand
|
||||
|
||||
|
||||
class MigrateXDG(Subcommand):
|
||||
"""CLI command for migrating from legacy ~/.llama to XDG-compliant directories."""
|
||||
|
||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||
super().__init__()
|
||||
self.parser = subparsers.add_parser(
|
||||
"migrate-xdg",
|
||||
prog="llama migrate-xdg",
|
||||
description="Migrate from legacy ~/.llama to XDG-compliant directories",
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
"--dry-run", action="store_true", help="Show what would be done without actually moving files"
|
||||
)
|
||||
|
||||
self.parser.set_defaults(func=self._run_migrate_xdg_cmd)
|
||||
|
||||
@staticmethod
|
||||
def create(subparsers: argparse._SubParsersAction):
|
||||
return MigrateXDG(subparsers)
|
||||
|
||||
def _run_migrate_xdg_cmd(self, args: argparse.Namespace) -> None:
|
||||
"""Run the migrate-xdg command."""
|
||||
if not migrate_to_xdg(dry_run=args.dry_run):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def migrate_to_xdg(dry_run: bool = False) -> bool:
|
||||
"""
|
||||
Migrate from legacy ~/.llama to XDG-compliant directories.
|
||||
|
||||
Args:
|
||||
dry_run: If True, only show what would be done without actually moving files
|
||||
|
||||
Returns:
|
||||
bool: True if migration was successful or not needed, False otherwise
|
||||
"""
|
||||
legacy_path = Path.home() / ".llama"
|
||||
|
||||
if not legacy_path.exists():
|
||||
print("No legacy ~/.llama directory found. Nothing to migrate.")
|
||||
return True
|
||||
|
||||
# Check if we're already using XDG paths
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
state_dir = get_llama_stack_state_dir()
|
||||
|
||||
if str(config_dir) == str(legacy_path):
|
||||
print("Already using legacy directory. No migration needed.")
|
||||
return True
|
||||
|
||||
print(f"Found legacy directory at: {legacy_path}")
|
||||
print("Will migrate to XDG-compliant directories:")
|
||||
print(f" Config: {config_dir}")
|
||||
print(f" Data: {data_dir}")
|
||||
print(f" State: {state_dir}")
|
||||
print()
|
||||
|
||||
# Define migration mapping
|
||||
migrations = [
|
||||
# (source_subdir, target_base_dir, description)
|
||||
("distributions", config_dir, "Distribution configurations"),
|
||||
("providers.d", config_dir, "External provider configurations"),
|
||||
("checkpoints", data_dir, "Model checkpoints"),
|
||||
("runtime", state_dir, "Runtime state files"),
|
||||
]
|
||||
|
||||
# Check what needs to be migrated
|
||||
items_to_migrate = []
|
||||
for subdir, target_base, description in migrations:
|
||||
source_path = legacy_path / subdir
|
||||
if source_path.exists():
|
||||
target_path = target_base / subdir
|
||||
items_to_migrate.append((source_path, target_path, description))
|
||||
|
||||
if not items_to_migrate:
|
||||
print("No items found to migrate.")
|
||||
return True
|
||||
|
||||
print("Items to migrate:")
|
||||
for source_path, target_path, description in items_to_migrate:
|
||||
print(f" {description}: {source_path} -> {target_path}")
|
||||
|
||||
if dry_run:
|
||||
print("\nDry run mode: No files will be moved.")
|
||||
return True
|
||||
|
||||
# Ask for confirmation
|
||||
response = input("\nDo you want to proceed with the migration? (y/N): ")
|
||||
if response.lower() not in ["y", "yes"]:
|
||||
print("Migration cancelled.")
|
||||
return False
|
||||
|
||||
# Perform the migration
|
||||
print("\nMigrating files...")
|
||||
|
||||
for source_path, target_path, description in items_to_migrate:
|
||||
try:
|
||||
# Create target directory if it doesn't exist
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if target already exists
|
||||
if target_path.exists():
|
||||
print(f" Warning: Target already exists: {target_path}")
|
||||
print(f" Skipping {description}")
|
||||
continue
|
||||
|
||||
# Move the directory
|
||||
shutil.move(str(source_path), str(target_path))
|
||||
print(f" Moved {description}: {source_path} -> {target_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" Error migrating {description}: {e}")
|
||||
return False
|
||||
|
||||
# Check if legacy directory is now empty (except for hidden files)
|
||||
remaining_items = [item for item in legacy_path.iterdir() if not item.name.startswith(".")]
|
||||
if not remaining_items:
|
||||
print(f"\nMigration complete! Legacy directory {legacy_path} is now empty.")
|
||||
response = input("Remove empty legacy directory? (y/N): ")
|
||||
if response.lower() in ["y", "yes"]:
|
||||
try:
|
||||
shutil.rmtree(legacy_path)
|
||||
print(f"Removed empty legacy directory: {legacy_path}")
|
||||
except Exception as e:
|
||||
print(f"Could not remove legacy directory: {e}")
|
||||
else:
|
||||
print(f"\nMigration complete! Some items remain in legacy directory: {remaining_items}")
|
||||
|
||||
print("\nMigration successful!")
|
||||
print("You may need to update any custom scripts or configurations that reference the old paths.")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Migrate from legacy ~/.llama to XDG-compliant directories")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Show what would be done without actually moving files")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not migrate_to_xdg(dry_run=args.dry_run):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -7,12 +7,35 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
|
||||
from .xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
)
|
||||
|
||||
# Base directory for all llama-stack configuration
|
||||
# This now uses XDG-compliant paths with backwards compatibility
|
||||
LLAMA_STACK_CONFIG_DIR = get_llama_stack_config_dir()
|
||||
|
||||
# Distribution configurations - stored in config directory
|
||||
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
# Model checkpoints - stored in data directory (persistent data)
|
||||
DEFAULT_CHECKPOINT_DIR = get_llama_stack_data_dir() / "checkpoints"
|
||||
|
||||
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
|
||||
# Runtime data - stored in state directory
|
||||
RUNTIME_BASE_DIR = get_llama_stack_state_dir() / "runtime"
|
||||
|
||||
# External providers - stored in config directory
|
||||
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
|
||||
|
||||
# Legacy compatibility: if the legacy environment variable is set, use it for all paths
|
||||
# This ensures that existing installations continue to work
|
||||
legacy_config_dir = os.getenv("LLAMA_STACK_CONFIG_DIR")
|
||||
if legacy_config_dir:
|
||||
legacy_base = Path(legacy_config_dir)
|
||||
LLAMA_STACK_CONFIG_DIR = legacy_base
|
||||
DISTRIBS_BASE_DIR = legacy_base / "distributions"
|
||||
DEFAULT_CHECKPOINT_DIR = legacy_base / "checkpoints"
|
||||
RUNTIME_BASE_DIR = legacy_base / "runtime"
|
||||
EXTERNAL_PROVIDERS_DIR = legacy_base / "providers.d"
|
||||
|
|
216
llama_stack/distribution/utils/xdg_utils.py
Normal file
216
llama_stack/distribution/utils/xdg_utils.py
Normal file
|
@ -0,0 +1,216 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_xdg_config_home() -> Path:
|
||||
"""
|
||||
Get the XDG config home directory.
|
||||
|
||||
Returns:
|
||||
Path: XDG_CONFIG_HOME if set, otherwise ~/.config
|
||||
"""
|
||||
return Path(os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")))
|
||||
|
||||
|
||||
def get_xdg_data_home() -> Path:
|
||||
"""
|
||||
Get the XDG data home directory.
|
||||
|
||||
Returns:
|
||||
Path: XDG_DATA_HOME if set, otherwise ~/.local/share
|
||||
"""
|
||||
return Path(os.environ.get("XDG_DATA_HOME", os.path.expanduser("~/.local/share")))
|
||||
|
||||
|
||||
def get_xdg_cache_home() -> Path:
|
||||
"""
|
||||
Get the XDG cache home directory.
|
||||
|
||||
Returns:
|
||||
Path: XDG_CACHE_HOME if set, otherwise ~/.cache
|
||||
"""
|
||||
return Path(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")))
|
||||
|
||||
|
||||
def get_xdg_state_home() -> Path:
|
||||
"""
|
||||
Get the XDG state home directory.
|
||||
|
||||
Returns:
|
||||
Path: XDG_STATE_HOME if set, otherwise ~/.local/state
|
||||
"""
|
||||
return Path(os.environ.get("XDG_STATE_HOME", os.path.expanduser("~/.local/state")))
|
||||
|
||||
|
||||
def get_llama_stack_config_dir() -> Path:
|
||||
"""
|
||||
Get the llama-stack configuration directory.
|
||||
|
||||
This function provides backwards compatibility by checking for the legacy
|
||||
LLAMA_STACK_CONFIG_DIR environment variable first, then falling back to
|
||||
XDG-compliant paths.
|
||||
|
||||
Returns:
|
||||
Path: Configuration directory for llama-stack
|
||||
"""
|
||||
# Check for legacy environment variable first for backwards compatibility
|
||||
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
|
||||
if legacy_dir:
|
||||
return Path(legacy_dir)
|
||||
|
||||
# Check if legacy ~/.llama directory exists and contains data
|
||||
legacy_path = Path.home() / ".llama"
|
||||
if legacy_path.exists() and any(legacy_path.iterdir()):
|
||||
return legacy_path
|
||||
|
||||
# Use XDG-compliant path
|
||||
return get_xdg_config_home() / "llama-stack"
|
||||
|
||||
|
||||
def get_llama_stack_data_dir() -> Path:
|
||||
"""
|
||||
Get the llama-stack data directory.
|
||||
|
||||
This is used for persistent data like model checkpoints.
|
||||
|
||||
Returns:
|
||||
Path: Data directory for llama-stack
|
||||
"""
|
||||
# Check for legacy environment variable first for backwards compatibility
|
||||
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
|
||||
if legacy_dir:
|
||||
return Path(legacy_dir)
|
||||
|
||||
# Check if legacy ~/.llama directory exists and contains data
|
||||
legacy_path = Path.home() / ".llama"
|
||||
if legacy_path.exists() and any(legacy_path.iterdir()):
|
||||
return legacy_path
|
||||
|
||||
# Use XDG-compliant path
|
||||
return get_xdg_data_home() / "llama-stack"
|
||||
|
||||
|
||||
def get_llama_stack_cache_dir() -> Path:
|
||||
"""
|
||||
Get the llama-stack cache directory.
|
||||
|
||||
This is used for temporary/cache data.
|
||||
|
||||
Returns:
|
||||
Path: Cache directory for llama-stack
|
||||
"""
|
||||
# Check for legacy environment variable first for backwards compatibility
|
||||
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
|
||||
if legacy_dir:
|
||||
return Path(legacy_dir)
|
||||
|
||||
# Check if legacy ~/.llama directory exists and contains data
|
||||
legacy_path = Path.home() / ".llama"
|
||||
if legacy_path.exists() and any(legacy_path.iterdir()):
|
||||
return legacy_path
|
||||
|
||||
# Use XDG-compliant path
|
||||
return get_xdg_cache_home() / "llama-stack"
|
||||
|
||||
|
||||
def get_llama_stack_state_dir() -> Path:
|
||||
"""
|
||||
Get the llama-stack state directory.
|
||||
|
||||
This is used for runtime state data.
|
||||
|
||||
Returns:
|
||||
Path: State directory for llama-stack
|
||||
"""
|
||||
# Check for legacy environment variable first for backwards compatibility
|
||||
legacy_dir = os.environ.get("LLAMA_STACK_CONFIG_DIR")
|
||||
if legacy_dir:
|
||||
return Path(legacy_dir)
|
||||
|
||||
# Check if legacy ~/.llama directory exists and contains data
|
||||
legacy_path = Path.home() / ".llama"
|
||||
if legacy_path.exists() and any(legacy_path.iterdir()):
|
||||
return legacy_path
|
||||
|
||||
# Use XDG-compliant path
|
||||
return get_xdg_state_home() / "llama-stack"
|
||||
|
||||
|
||||
def get_xdg_compliant_path(path_type: str, subdirectory: str | None = None, legacy_fallback: bool = True) -> Path:
|
||||
"""
|
||||
Get an XDG-compliant path for a given type.
|
||||
|
||||
Args:
|
||||
path_type: Type of path ('config', 'data', 'cache', 'state')
|
||||
subdirectory: Optional subdirectory within the base path
|
||||
legacy_fallback: Whether to check for legacy ~/.llama directory
|
||||
|
||||
Returns:
|
||||
Path: XDG-compliant path
|
||||
|
||||
Raises:
|
||||
ValueError: If path_type is not recognized
|
||||
"""
|
||||
path_map = {
|
||||
"config": get_llama_stack_config_dir,
|
||||
"data": get_llama_stack_data_dir,
|
||||
"cache": get_llama_stack_cache_dir,
|
||||
"state": get_llama_stack_state_dir,
|
||||
}
|
||||
|
||||
if path_type not in path_map:
|
||||
raise ValueError(f"Unknown path type: {path_type}. Must be one of: {list(path_map.keys())}")
|
||||
|
||||
base_path = path_map[path_type]()
|
||||
|
||||
if subdirectory:
|
||||
return base_path / subdirectory
|
||||
|
||||
return base_path
|
||||
|
||||
|
||||
def migrate_legacy_directory() -> bool:
|
||||
"""
|
||||
Migrate from legacy ~/.llama directory to XDG-compliant directories.
|
||||
|
||||
This function helps users migrate their existing data to the new
|
||||
XDG-compliant structure.
|
||||
|
||||
Returns:
|
||||
bool: True if migration was successful or not needed, False otherwise
|
||||
"""
|
||||
legacy_path = Path.home() / ".llama"
|
||||
|
||||
if not legacy_path.exists():
|
||||
return True # No migration needed
|
||||
|
||||
print(f"Found legacy directory at {legacy_path}")
|
||||
print("Consider migrating to XDG-compliant directories:")
|
||||
print(f" Config: {get_llama_stack_config_dir()}")
|
||||
print(f" Data: {get_llama_stack_data_dir()}")
|
||||
print(f" Cache: {get_llama_stack_cache_dir()}")
|
||||
print(f" State: {get_llama_stack_state_dir()}")
|
||||
print("Migration can be done by moving the appropriate subdirectories.")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def ensure_directory_exists(path: Path) -> None:
|
||||
"""
|
||||
Ensure a directory exists, creating it if necessary.
|
||||
|
||||
Args:
|
||||
path: Path to the directory
|
||||
"""
|
||||
path.mkdir(parents=True, exist_ok=True)
|
|
@ -4,6 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
|
|
|
@ -12,23 +12,12 @@ Type-safe data interchange for Python data classes.
|
|||
|
||||
import dataclasses
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Callable, Dict, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated as Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated as Annotated
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias as TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias as TypeAlias
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import dataclass_transform as dataclass_transform
|
||||
else:
|
||||
from typing_extensions import dataclass_transform as dataclass_transform
|
||||
from typing import Annotated as Annotated
|
||||
from typing import TypeAlias as TypeAlias
|
||||
from typing import TypeVar, overload
|
||||
from typing import dataclass_transform as dataclass_transform
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -56,17 +45,17 @@ class CompactDataClass:
|
|||
|
||||
|
||||
@overload
|
||||
def typeannotation(cls: Type[T], /) -> Type[T]: ...
|
||||
def typeannotation(cls: type[T], /) -> type[T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[Type[T]], Type[T]]: ...
|
||||
def typeannotation(cls: None, *, eq: bool = True, order: bool = False) -> Callable[[type[T]], type[T]]: ...
|
||||
|
||||
|
||||
@dataclass_transform(eq_default=True, order_default=False)
|
||||
def typeannotation(
|
||||
cls: Optional[Type[T]] = None, *, eq: bool = True, order: bool = False
|
||||
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
|
||||
cls: type[T] | None = None, *, eq: bool = True, order: bool = False
|
||||
) -> type[T] | Callable[[type[T]], type[T]]:
|
||||
"""
|
||||
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
|
||||
|
||||
|
@ -76,7 +65,7 @@ def typeannotation(
|
|||
:returns: A data-class type, or a wrapper for data-class types.
|
||||
"""
|
||||
|
||||
def wrap(cls: Type[T]) -> Type[T]:
|
||||
def wrap(cls: type[T]) -> type[T]:
|
||||
# mypy fails to equate bound-y functions (first argument interpreted as
|
||||
# the bound object) with class methods, hence the `ignore` directive.
|
||||
cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign]
|
||||
|
@ -179,41 +168,41 @@ class SpecialConversion:
|
|||
"Indicates that the annotated type is subject to custom conversion rules."
|
||||
|
||||
|
||||
int8: TypeAlias = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
|
||||
int16: TypeAlias = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
|
||||
int32: TypeAlias = Annotated[
|
||||
type int8 = Annotated[int, Signed(True), Storage(1), IntegerRange(-128, 127)]
|
||||
type int16 = Annotated[int, Signed(True), Storage(2), IntegerRange(-32768, 32767)]
|
||||
type int32 = Annotated[
|
||||
int,
|
||||
Signed(True),
|
||||
Storage(4),
|
||||
IntegerRange(-2147483648, 2147483647),
|
||||
]
|
||||
int64: TypeAlias = Annotated[
|
||||
type int64 = Annotated[
|
||||
int,
|
||||
Signed(True),
|
||||
Storage(8),
|
||||
IntegerRange(-9223372036854775808, 9223372036854775807),
|
||||
]
|
||||
|
||||
uint8: TypeAlias = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
|
||||
uint16: TypeAlias = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
|
||||
uint32: TypeAlias = Annotated[
|
||||
type uint8 = Annotated[int, Signed(False), Storage(1), IntegerRange(0, 255)]
|
||||
type uint16 = Annotated[int, Signed(False), Storage(2), IntegerRange(0, 65535)]
|
||||
type uint32 = Annotated[
|
||||
int,
|
||||
Signed(False),
|
||||
Storage(4),
|
||||
IntegerRange(0, 4294967295),
|
||||
]
|
||||
uint64: TypeAlias = Annotated[
|
||||
type uint64 = Annotated[
|
||||
int,
|
||||
Signed(False),
|
||||
Storage(8),
|
||||
IntegerRange(0, 18446744073709551615),
|
||||
]
|
||||
|
||||
float32: TypeAlias = Annotated[float, Storage(4)]
|
||||
float64: TypeAlias = Annotated[float, Storage(8)]
|
||||
type float32 = Annotated[float, Storage(4)]
|
||||
type float64 = Annotated[float, Storage(8)]
|
||||
|
||||
# maps globals of type Annotated[T, ...] defined in this module to their string names
|
||||
_auxiliary_types: Dict[object, str] = {}
|
||||
_auxiliary_types: dict[object, str] = {}
|
||||
module = sys.modules[__name__]
|
||||
for var in dir(module):
|
||||
typ = getattr(module, var)
|
||||
|
@ -222,7 +211,7 @@ for var in dir(module):
|
|||
_auxiliary_types[typ] = var
|
||||
|
||||
|
||||
def get_auxiliary_format(data_type: object) -> Optional[str]:
|
||||
def get_auxiliary_format(data_type: object) -> str | None:
|
||||
"Returns the JSON format string corresponding to an auxiliary type."
|
||||
|
||||
return _auxiliary_types.get(data_type)
|
||||
|
|
|
@ -12,12 +12,11 @@ import enum
|
|||
import ipaddress
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Literal, TypeVar, Union
|
||||
|
||||
from .auxiliary import (
|
||||
Alias,
|
||||
|
@ -40,57 +39,57 @@ T = TypeVar("T")
|
|||
|
||||
@dataclass
|
||||
class JsonSchemaNode:
|
||||
title: Optional[str]
|
||||
description: Optional[str]
|
||||
title: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaType(JsonSchemaNode):
|
||||
type: str
|
||||
format: Optional[str]
|
||||
format: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaBoolean(JsonSchemaType):
|
||||
type: Literal["boolean"]
|
||||
const: Optional[bool]
|
||||
default: Optional[bool]
|
||||
examples: Optional[List[bool]]
|
||||
const: bool | None
|
||||
default: bool | None
|
||||
examples: list[bool] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaInteger(JsonSchemaType):
|
||||
type: Literal["integer"]
|
||||
const: Optional[int]
|
||||
default: Optional[int]
|
||||
examples: Optional[List[int]]
|
||||
enum: Optional[List[int]]
|
||||
minimum: Optional[int]
|
||||
maximum: Optional[int]
|
||||
const: int | None
|
||||
default: int | None
|
||||
examples: list[int] | None
|
||||
enum: list[int] | None
|
||||
minimum: int | None
|
||||
maximum: int | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaNumber(JsonSchemaType):
|
||||
type: Literal["number"]
|
||||
const: Optional[float]
|
||||
default: Optional[float]
|
||||
examples: Optional[List[float]]
|
||||
minimum: Optional[float]
|
||||
maximum: Optional[float]
|
||||
exclusiveMinimum: Optional[float]
|
||||
exclusiveMaximum: Optional[float]
|
||||
multipleOf: Optional[float]
|
||||
const: float | None
|
||||
default: float | None
|
||||
examples: list[float] | None
|
||||
minimum: float | None
|
||||
maximum: float | None
|
||||
exclusiveMinimum: float | None
|
||||
exclusiveMaximum: float | None
|
||||
multipleOf: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaString(JsonSchemaType):
|
||||
type: Literal["string"]
|
||||
const: Optional[str]
|
||||
default: Optional[str]
|
||||
examples: Optional[List[str]]
|
||||
enum: Optional[List[str]]
|
||||
minLength: Optional[int]
|
||||
maxLength: Optional[int]
|
||||
const: str | None
|
||||
default: str | None
|
||||
examples: list[str] | None
|
||||
enum: list[str] | None
|
||||
minLength: int | None
|
||||
maxLength: int | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -102,9 +101,9 @@ class JsonSchemaArray(JsonSchemaType):
|
|||
@dataclass
|
||||
class JsonSchemaObject(JsonSchemaType):
|
||||
type: Literal["object"]
|
||||
properties: Optional[Dict[str, "JsonSchemaAny"]]
|
||||
additionalProperties: Optional[bool]
|
||||
required: Optional[List[str]]
|
||||
properties: dict[str, "JsonSchemaAny"] | None
|
||||
additionalProperties: bool | None
|
||||
required: list[str] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -114,24 +113,24 @@ class JsonSchemaRef(JsonSchemaNode):
|
|||
|
||||
@dataclass
|
||||
class JsonSchemaAllOf(JsonSchemaNode):
|
||||
allOf: List["JsonSchemaAny"]
|
||||
allOf: list["JsonSchemaAny"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaAnyOf(JsonSchemaNode):
|
||||
anyOf: List["JsonSchemaAny"]
|
||||
anyOf: list["JsonSchemaAny"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Discriminator:
|
||||
propertyName: str
|
||||
mapping: Dict[str, str]
|
||||
mapping: dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonSchemaOneOf(JsonSchemaNode):
|
||||
oneOf: List["JsonSchemaAny"]
|
||||
discriminator: Optional[Discriminator]
|
||||
oneOf: list["JsonSchemaAny"]
|
||||
discriminator: Discriminator | None
|
||||
|
||||
|
||||
JsonSchemaAny = Union[
|
||||
|
@ -149,10 +148,10 @@ JsonSchemaAny = Union[
|
|||
@dataclass
|
||||
class JsonSchemaTopLevelObject(JsonSchemaObject):
|
||||
schema: Annotated[str, Alias("$schema")]
|
||||
definitions: Optional[Dict[str, JsonSchemaAny]]
|
||||
definitions: dict[str, JsonSchemaAny] | None
|
||||
|
||||
|
||||
def integer_range_to_type(min_value: float, max_value: float) -> type:
|
||||
def integer_range_to_type(min_value: float, max_value: float) -> Any:
|
||||
if min_value >= -(2**15) and max_value < 2**15:
|
||||
return int16
|
||||
elif min_value >= -(2**31) and max_value < 2**31:
|
||||
|
@ -173,11 +172,11 @@ def enum_safe_name(name: str) -> str:
|
|||
def enum_values_to_type(
|
||||
module: types.ModuleType,
|
||||
name: str,
|
||||
values: Dict[str, Any],
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Type[enum.Enum]:
|
||||
enum_class: Type[enum.Enum] = enum.Enum(name, values) # type: ignore
|
||||
values: dict[str, Any],
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> type[enum.Enum]:
|
||||
enum_class: type[enum.Enum] = enum.Enum(name, values) # type: ignore
|
||||
|
||||
# assign the newly created type to the same module where the defining class is
|
||||
enum_class.__module__ = module.__name__
|
||||
|
@ -330,7 +329,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
|
|||
type_def = node_to_typedef(module, context, node.items)
|
||||
if type_def.default is not dataclasses.MISSING:
|
||||
raise TypeError("disallowed: `default` for array element type")
|
||||
list_type = List[(type_def.type,)] # type: ignore
|
||||
list_type = list[(type_def.type,)] # type: ignore
|
||||
return TypeDef(list_type, dataclasses.MISSING)
|
||||
|
||||
elif isinstance(node, JsonSchemaObject):
|
||||
|
@ -344,8 +343,8 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
|
|||
|
||||
class_name = context
|
||||
|
||||
fields: List[Tuple[str, Any, dataclasses.Field]] = []
|
||||
params: Dict[str, DocstringParam] = {}
|
||||
fields: list[tuple[str, Any, dataclasses.Field]] = []
|
||||
params: dict[str, DocstringParam] = {}
|
||||
for prop_name, prop_node in node.properties.items():
|
||||
type_def = node_to_typedef(module, f"{class_name}__{prop_name}", prop_node)
|
||||
if prop_name in required:
|
||||
|
@ -358,10 +357,7 @@ def node_to_typedef(module: types.ModuleType, context: str, node: JsonSchemaNode
|
|||
params[prop_name] = DocstringParam(prop_name, prop_desc)
|
||||
|
||||
fields.sort(key=lambda t: t[2].default is not dataclasses.MISSING)
|
||||
if sys.version_info >= (3, 12):
|
||||
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
|
||||
else:
|
||||
class_type = dataclasses.make_dataclass(class_name, fields, namespace={"__module__": module.__name__})
|
||||
class_type = dataclasses.make_dataclass(class_name, fields, module=module.__name__)
|
||||
class_type.__doc__ = str(
|
||||
Docstring(
|
||||
short_description=node.title,
|
||||
|
@ -388,7 +384,7 @@ class SchemaFlatteningOptions:
|
|||
recursive: bool = False
|
||||
|
||||
|
||||
def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions] = None) -> Schema:
|
||||
def flatten_schema(schema: Schema, *, options: SchemaFlatteningOptions | None = None) -> Schema:
|
||||
top_node = typing.cast(JsonSchemaTopLevelObject, json_to_object(JsonSchemaTopLevelObject, schema))
|
||||
flattener = SchemaFlattener(options)
|
||||
obj = flattener.flatten(top_node)
|
||||
|
@ -398,7 +394,7 @@ def flatten_schema(schema: Schema, *, options: Optional[SchemaFlatteningOptions]
|
|||
class SchemaFlattener:
|
||||
options: SchemaFlatteningOptions
|
||||
|
||||
def __init__(self, options: Optional[SchemaFlatteningOptions] = None) -> None:
|
||||
def __init__(self, options: SchemaFlatteningOptions | None = None) -> None:
|
||||
self.options = options or SchemaFlatteningOptions()
|
||||
|
||||
def flatten(self, source_node: JsonSchemaObject) -> JsonSchemaObject:
|
||||
|
@ -406,10 +402,10 @@ class SchemaFlattener:
|
|||
return source_node
|
||||
|
||||
source_props = source_node.properties or {}
|
||||
target_props: Dict[str, JsonSchemaAny] = {}
|
||||
target_props: dict[str, JsonSchemaAny] = {}
|
||||
|
||||
source_reqs = source_node.required or []
|
||||
target_reqs: List[str] = []
|
||||
target_reqs: list[str] = []
|
||||
|
||||
for name, prop in source_props.items():
|
||||
if not isinstance(prop, JsonSchemaObject):
|
||||
|
|
|
@ -10,7 +10,7 @@ Type-safe data interchange for Python data classes.
|
|||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Union
|
||||
from typing import Union
|
||||
|
||||
|
||||
class JsonObject:
|
||||
|
@ -28,8 +28,8 @@ JsonType = Union[
|
|||
int,
|
||||
float,
|
||||
str,
|
||||
Dict[str, "JsonType"],
|
||||
List["JsonType"],
|
||||
dict[str, "JsonType"],
|
||||
list["JsonType"],
|
||||
]
|
||||
|
||||
# a JSON type that cannot contain `null` values
|
||||
|
@ -38,9 +38,9 @@ StrictJsonType = Union[
|
|||
int,
|
||||
float,
|
||||
str,
|
||||
Dict[str, "StrictJsonType"],
|
||||
List["StrictJsonType"],
|
||||
dict[str, "StrictJsonType"],
|
||||
list["StrictJsonType"],
|
||||
]
|
||||
|
||||
# a meta-type that captures the object type in a JSON schema
|
||||
Schema = Dict[str, JsonType]
|
||||
Schema = dict[str, JsonType]
|
||||
|
|
|
@ -20,19 +20,14 @@ import ipaddress
|
|||
import sys
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
@ -70,7 +65,7 @@ V = TypeVar("V")
|
|||
class Deserializer(abc.ABC, Generic[T]):
|
||||
"Parses a JSON value into a Python type."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
"""
|
||||
Creates auxiliary parsers that this parser is depending on.
|
||||
|
||||
|
@ -203,19 +198,19 @@ class IPv6Deserializer(Deserializer[ipaddress.IPv6Address]):
|
|||
return ipaddress.IPv6Address(data)
|
||||
|
||||
|
||||
class ListDeserializer(Deserializer[List[T]]):
|
||||
class ListDeserializer(Deserializer[list[T]]):
|
||||
"Recursively de-serializes a JSON array into a Python `list`."
|
||||
|
||||
item_type: Type[T]
|
||||
item_type: type[T]
|
||||
item_parser: Deserializer
|
||||
|
||||
def __init__(self, item_type: Type[T]) -> None:
|
||||
def __init__(self, item_type: type[T]) -> None:
|
||||
self.item_type = item_type
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.item_parser = _get_deserializer(self.item_type, context)
|
||||
|
||||
def parse(self, data: JsonType) -> List[T]:
|
||||
def parse(self, data: JsonType) -> list[T]:
|
||||
if not isinstance(data, list):
|
||||
type_name = python_type_to_str(self.item_type)
|
||||
raise JsonTypeError(f"type `List[{type_name}]` expects JSON `array` data but instead received: {data}")
|
||||
|
@ -223,19 +218,19 @@ class ListDeserializer(Deserializer[List[T]]):
|
|||
return [self.item_parser.parse(item) for item in data]
|
||||
|
||||
|
||||
class DictDeserializer(Deserializer[Dict[K, V]]):
|
||||
class DictDeserializer(Deserializer[dict[K, V]]):
|
||||
"Recursively de-serializes a JSON object into a Python `dict`."
|
||||
|
||||
key_type: Type[K]
|
||||
value_type: Type[V]
|
||||
key_type: type[K]
|
||||
value_type: type[V]
|
||||
value_parser: Deserializer[V]
|
||||
|
||||
def __init__(self, key_type: Type[K], value_type: Type[V]) -> None:
|
||||
def __init__(self, key_type: type[K], value_type: type[V]) -> None:
|
||||
self.key_type = key_type
|
||||
self.value_type = value_type
|
||||
self._check_key_type()
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.value_parser = _get_deserializer(self.value_type, context)
|
||||
|
||||
def _check_key_type(self) -> None:
|
||||
|
@ -264,48 +259,48 @@ class DictDeserializer(Deserializer[Dict[K, V]]):
|
|||
value_type_name = python_type_to_str(self.value_type)
|
||||
return f"Dict[{key_type_name}, {value_type_name}]"
|
||||
|
||||
def parse(self, data: JsonType) -> Dict[K, V]:
|
||||
def parse(self, data: JsonType) -> dict[K, V]:
|
||||
if not isinstance(data, dict):
|
||||
raise JsonTypeError(
|
||||
f"`type `{self.container_type}` expects JSON `object` data but instead received: {data}"
|
||||
)
|
||||
|
||||
return dict(
|
||||
(self.key_type(key), self.value_parser.parse(value)) # type: ignore[call-arg]
|
||||
return {
|
||||
self.key_type(key): self.value_parser.parse(value) # type: ignore[call-arg]
|
||||
for key, value in data.items()
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class SetDeserializer(Deserializer[Set[T]]):
|
||||
class SetDeserializer(Deserializer[set[T]]):
|
||||
"Recursively de-serializes a JSON list into a Python `set`."
|
||||
|
||||
member_type: Type[T]
|
||||
member_type: type[T]
|
||||
member_parser: Deserializer
|
||||
|
||||
def __init__(self, member_type: Type[T]) -> None:
|
||||
def __init__(self, member_type: type[T]) -> None:
|
||||
self.member_type = member_type
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.member_parser = _get_deserializer(self.member_type, context)
|
||||
|
||||
def parse(self, data: JsonType) -> Set[T]:
|
||||
def parse(self, data: JsonType) -> set[T]:
|
||||
if not isinstance(data, list):
|
||||
type_name = python_type_to_str(self.member_type)
|
||||
raise JsonTypeError(f"type `Set[{type_name}]` expects JSON `array` data but instead received: {data}")
|
||||
|
||||
return set(self.member_parser.parse(item) for item in data)
|
||||
return {self.member_parser.parse(item) for item in data}
|
||||
|
||||
|
||||
class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
|
||||
class TupleDeserializer(Deserializer[tuple[Any, ...]]):
|
||||
"Recursively de-serializes a JSON list into a Python `tuple`."
|
||||
|
||||
item_types: Tuple[Type[Any], ...]
|
||||
item_parsers: Tuple[Deserializer[Any], ...]
|
||||
item_types: tuple[type[Any], ...]
|
||||
item_parsers: tuple[Deserializer[Any], ...]
|
||||
|
||||
def __init__(self, item_types: Tuple[Type[Any], ...]) -> None:
|
||||
def __init__(self, item_types: tuple[type[Any], ...]) -> None:
|
||||
self.item_types = item_types
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.item_parsers = tuple(_get_deserializer(item_type, context) for item_type in self.item_types)
|
||||
|
||||
@property
|
||||
|
@ -313,7 +308,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
|
|||
type_names = ", ".join(python_type_to_str(item_type) for item_type in self.item_types)
|
||||
return f"Tuple[{type_names}]"
|
||||
|
||||
def parse(self, data: JsonType) -> Tuple[Any, ...]:
|
||||
def parse(self, data: JsonType) -> tuple[Any, ...]:
|
||||
if not isinstance(data, list) or len(data) != len(self.item_parsers):
|
||||
if not isinstance(data, list):
|
||||
raise JsonTypeError(
|
||||
|
@ -331,13 +326,13 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
|
|||
class UnionDeserializer(Deserializer):
|
||||
"De-serializes a JSON value (of any type) into a Python union type."
|
||||
|
||||
member_types: Tuple[type, ...]
|
||||
member_parsers: Tuple[Deserializer, ...]
|
||||
member_types: tuple[type, ...]
|
||||
member_parsers: tuple[Deserializer, ...]
|
||||
|
||||
def __init__(self, member_types: Tuple[type, ...]) -> None:
|
||||
def __init__(self, member_types: tuple[type, ...]) -> None:
|
||||
self.member_types = member_types
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.member_parsers = tuple(_get_deserializer(member_type, context) for member_type in self.member_types)
|
||||
|
||||
def parse(self, data: JsonType) -> Any:
|
||||
|
@ -354,15 +349,15 @@ class UnionDeserializer(Deserializer):
|
|||
raise JsonKeyError(f"type `Union[{type_names}]` could not be instantiated from: {data}")
|
||||
|
||||
|
||||
def get_literal_properties(typ: type) -> Set[str]:
|
||||
def get_literal_properties(typ: type) -> set[str]:
|
||||
"Returns the names of all properties in a class that are of a literal type."
|
||||
|
||||
return set(
|
||||
return {
|
||||
property_name for property_name, property_type in get_class_properties(typ) if is_type_literal(property_type)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
|
||||
def get_discriminating_properties(types: tuple[type, ...]) -> set[str]:
|
||||
"Returns a set of properties with literal type that are common across all specified classes."
|
||||
|
||||
if not types or not all(isinstance(typ, type) for typ in types):
|
||||
|
@ -378,15 +373,15 @@ def get_discriminating_properties(types: Tuple[type, ...]) -> Set[str]:
|
|||
class TaggedUnionDeserializer(Deserializer):
|
||||
"De-serializes a JSON value with one or more disambiguating properties into a Python union type."
|
||||
|
||||
member_types: Tuple[type, ...]
|
||||
disambiguating_properties: Set[str]
|
||||
member_parsers: Dict[Tuple[str, Any], Deserializer]
|
||||
member_types: tuple[type, ...]
|
||||
disambiguating_properties: set[str]
|
||||
member_parsers: dict[tuple[str, Any], Deserializer]
|
||||
|
||||
def __init__(self, member_types: Tuple[type, ...]) -> None:
|
||||
def __init__(self, member_types: tuple[type, ...]) -> None:
|
||||
self.member_types = member_types
|
||||
self.disambiguating_properties = get_discriminating_properties(member_types)
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
self.member_parsers = {}
|
||||
for member_type in self.member_types:
|
||||
for property_name in self.disambiguating_properties:
|
||||
|
@ -435,13 +430,13 @@ class TaggedUnionDeserializer(Deserializer):
|
|||
class LiteralDeserializer(Deserializer):
|
||||
"De-serializes a JSON value into a Python literal type."
|
||||
|
||||
values: Tuple[Any, ...]
|
||||
values: tuple[Any, ...]
|
||||
parser: Deserializer
|
||||
|
||||
def __init__(self, values: Tuple[Any, ...]) -> None:
|
||||
def __init__(self, values: tuple[Any, ...]) -> None:
|
||||
self.values = values
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
literal_type_tuple = tuple(type(value) for value in self.values)
|
||||
literal_type_set = set(literal_type_tuple)
|
||||
if len(literal_type_set) != 1:
|
||||
|
@ -464,9 +459,9 @@ class LiteralDeserializer(Deserializer):
|
|||
class EnumDeserializer(Deserializer[E]):
|
||||
"Returns an enumeration instance based on the enumeration value read from a JSON value."
|
||||
|
||||
enum_type: Type[E]
|
||||
enum_type: type[E]
|
||||
|
||||
def __init__(self, enum_type: Type[E]) -> None:
|
||||
def __init__(self, enum_type: type[E]) -> None:
|
||||
self.enum_type = enum_type
|
||||
|
||||
def parse(self, data: JsonType) -> E:
|
||||
|
@ -504,13 +499,13 @@ class FieldDeserializer(abc.ABC, Generic[T, R]):
|
|||
self.parser = parser
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> R: ...
|
||||
def parse_field(self, data: dict[str, JsonType]) -> R: ...
|
||||
|
||||
|
||||
class RequiredFieldDeserializer(FieldDeserializer[T, T]):
|
||||
"Deserializes a JSON property into a mandatory Python object field."
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
def parse_field(self, data: dict[str, JsonType]) -> T:
|
||||
if self.property_name not in data:
|
||||
raise JsonKeyError(f"missing required property `{self.property_name}` from JSON object: {data}")
|
||||
|
||||
|
@ -520,7 +515,7 @@ class RequiredFieldDeserializer(FieldDeserializer[T, T]):
|
|||
class OptionalFieldDeserializer(FieldDeserializer[T, Optional[T]]):
|
||||
"Deserializes a JSON property into an optional Python object field with a default value of `None`."
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> Optional[T]:
|
||||
def parse_field(self, data: dict[str, JsonType]) -> T | None:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
|
@ -543,7 +538,7 @@ class DefaultFieldDeserializer(FieldDeserializer[T, T]):
|
|||
super().__init__(property_name, field_name, parser)
|
||||
self.default_value = default_value
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
def parse_field(self, data: dict[str, JsonType]) -> T:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
|
@ -566,7 +561,7 @@ class DefaultFactoryFieldDeserializer(FieldDeserializer[T, T]):
|
|||
super().__init__(property_name, field_name, parser)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def parse_field(self, data: Dict[str, JsonType]) -> T:
|
||||
def parse_field(self, data: dict[str, JsonType]) -> T:
|
||||
value = data.get(self.property_name)
|
||||
if value is not None:
|
||||
return self.parser.parse(value)
|
||||
|
@ -578,22 +573,22 @@ class ClassDeserializer(Deserializer[T]):
|
|||
"Base class for de-serializing class-like types such as data classes, named tuples and regular classes."
|
||||
|
||||
class_type: type
|
||||
property_parsers: List[FieldDeserializer]
|
||||
property_fields: Set[str]
|
||||
property_parsers: list[FieldDeserializer]
|
||||
property_fields: set[str]
|
||||
|
||||
def __init__(self, class_type: Type[T]) -> None:
|
||||
def __init__(self, class_type: type[T]) -> None:
|
||||
self.class_type = class_type
|
||||
|
||||
def assign(self, property_parsers: List[FieldDeserializer]) -> None:
|
||||
def assign(self, property_parsers: list[FieldDeserializer]) -> None:
|
||||
self.property_parsers = property_parsers
|
||||
self.property_fields = set(property_parser.property_name for property_parser in property_parsers)
|
||||
self.property_fields = {property_parser.property_name for property_parser in property_parsers}
|
||||
|
||||
def parse(self, data: JsonType) -> T:
|
||||
if not isinstance(data, dict):
|
||||
type_name = python_type_to_str(self.class_type)
|
||||
raise JsonTypeError(f"`type `{type_name}` expects JSON `object` data but instead received: {data}")
|
||||
|
||||
object_data: Dict[str, JsonType] = typing.cast(Dict[str, JsonType], data)
|
||||
object_data: dict[str, JsonType] = typing.cast(dict[str, JsonType], data)
|
||||
|
||||
field_values = {}
|
||||
for property_parser in self.property_parsers:
|
||||
|
@ -619,8 +614,8 @@ class ClassDeserializer(Deserializer[T]):
|
|||
class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
|
||||
"De-serializes a named tuple from a JSON `object`."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = [
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
property_parsers: list[FieldDeserializer] = [
|
||||
RequiredFieldDeserializer(field_name, field_name, _get_deserializer(field_type, context))
|
||||
for field_name, field_type in get_resolved_hints(self.class_type).items()
|
||||
]
|
||||
|
@ -634,13 +629,13 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
|
|||
class DataclassDeserializer(ClassDeserializer[T]):
|
||||
"De-serializes a data class from a JSON `object`."
|
||||
|
||||
def __init__(self, class_type: Type[T]) -> None:
|
||||
def __init__(self, class_type: type[T]) -> None:
|
||||
if not dataclasses.is_dataclass(class_type):
|
||||
raise TypeError("expected: data-class type")
|
||||
super().__init__(class_type) # type: ignore[arg-type]
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = []
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
property_parsers: list[FieldDeserializer] = []
|
||||
resolved_hints = get_resolved_hints(self.class_type)
|
||||
for field in dataclasses.fields(self.class_type):
|
||||
field_type = resolved_hints[field.name]
|
||||
|
@ -651,7 +646,7 @@ class DataclassDeserializer(ClassDeserializer[T]):
|
|||
has_default_factory = field.default_factory is not dataclasses.MISSING
|
||||
|
||||
if is_optional:
|
||||
required_type: Type[T] = unwrap_optional_type(field_type)
|
||||
required_type: type[T] = unwrap_optional_type(field_type)
|
||||
else:
|
||||
required_type = field_type
|
||||
|
||||
|
@ -691,15 +686,15 @@ class FrozenDataclassDeserializer(DataclassDeserializer[T]):
|
|||
class TypedClassDeserializer(ClassDeserializer[T]):
|
||||
"De-serializes a class with type annotations from a JSON `object` by iterating over class properties."
|
||||
|
||||
def build(self, context: Optional[ModuleType]) -> None:
|
||||
property_parsers: List[FieldDeserializer] = []
|
||||
def build(self, context: ModuleType | None) -> None:
|
||||
property_parsers: list[FieldDeserializer] = []
|
||||
for field_name, field_type in get_resolved_hints(self.class_type).items():
|
||||
property_name = python_field_to_json_property(field_name, field_type)
|
||||
|
||||
is_optional = is_type_optional(field_type)
|
||||
|
||||
if is_optional:
|
||||
required_type: Type[T] = unwrap_optional_type(field_type)
|
||||
required_type: type[T] = unwrap_optional_type(field_type)
|
||||
else:
|
||||
required_type = field_type
|
||||
|
||||
|
@ -715,7 +710,7 @@ class TypedClassDeserializer(ClassDeserializer[T]):
|
|||
super().assign(property_parsers)
|
||||
|
||||
|
||||
def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Deserializer:
|
||||
def create_deserializer(typ: TypeLike, context: ModuleType | None = None) -> Deserializer:
|
||||
"""
|
||||
Creates a de-serializer engine to produce a Python object from an object obtained from a JSON string.
|
||||
|
||||
|
@ -741,15 +736,15 @@ def create_deserializer(typ: TypeLike, context: Optional[ModuleType] = None) ->
|
|||
return _get_deserializer(typ, context)
|
||||
|
||||
|
||||
_CACHE: Dict[Tuple[str, str], Deserializer] = {}
|
||||
_CACHE: dict[tuple[str, str], Deserializer] = {}
|
||||
|
||||
|
||||
def _get_deserializer(typ: TypeLike, context: Optional[ModuleType]) -> Deserializer:
|
||||
def _get_deserializer(typ: TypeLike, context: ModuleType | None) -> Deserializer:
|
||||
"Creates or re-uses a de-serializer engine to parse an object obtained from a JSON string."
|
||||
|
||||
cache_key = None
|
||||
|
||||
if isinstance(typ, (str, typing.ForwardRef)):
|
||||
if isinstance(typ, str | typing.ForwardRef):
|
||||
if context is None:
|
||||
raise TypeError(f"missing context for evaluating type: {typ}")
|
||||
|
||||
|
|
|
@ -15,17 +15,12 @@ import collections.abc
|
|||
import dataclasses
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from io import StringIO
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, Type, TypeVar
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
from typing import Any, Protocol, TypeGuard, TypeVar
|
||||
|
||||
from .inspection import (
|
||||
DataclassInstance,
|
||||
|
@ -110,14 +105,14 @@ class Docstring:
|
|||
:param returns: The returns declaration extracted from a docstring.
|
||||
"""
|
||||
|
||||
short_description: Optional[str] = None
|
||||
long_description: Optional[str] = None
|
||||
params: Dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
|
||||
returns: Optional[DocstringReturns] = None
|
||||
raises: Dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
|
||||
short_description: str | None = None
|
||||
long_description: str | None = None
|
||||
params: dict[str, DocstringParam] = dataclasses.field(default_factory=dict)
|
||||
returns: DocstringReturns | None = None
|
||||
raises: dict[str, DocstringRaises] = dataclasses.field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def full_description(self) -> Optional[str]:
|
||||
def full_description(self) -> str | None:
|
||||
if self.short_description and self.long_description:
|
||||
return f"{self.short_description}\n\n{self.long_description}"
|
||||
elif self.short_description:
|
||||
|
@ -158,18 +153,18 @@ class Docstring:
|
|||
return s
|
||||
|
||||
|
||||
def is_exception(member: object) -> TypeGuard[Type[BaseException]]:
|
||||
def is_exception(member: object) -> TypeGuard[type[BaseException]]:
|
||||
return isinstance(member, type) and issubclass(member, BaseException)
|
||||
|
||||
|
||||
def get_exceptions(module: types.ModuleType) -> Dict[str, Type[BaseException]]:
|
||||
def get_exceptions(module: types.ModuleType) -> dict[str, type[BaseException]]:
|
||||
"Returns all exception classes declared in a module."
|
||||
|
||||
return {name: class_type for name, class_type in inspect.getmembers(module, is_exception)}
|
||||
return dict(inspect.getmembers(module, is_exception))
|
||||
|
||||
|
||||
class SupportsDoc(Protocol):
|
||||
__doc__: Optional[str]
|
||||
__doc__: str | None
|
||||
|
||||
|
||||
def _maybe_unwrap_async_iterator(t):
|
||||
|
@ -213,7 +208,7 @@ def parse_type(typ: SupportsDoc) -> Docstring:
|
|||
# assign exception types
|
||||
defining_module = inspect.getmodule(typ)
|
||||
if defining_module:
|
||||
context: Dict[str, type] = {}
|
||||
context: dict[str, type] = {}
|
||||
context.update(get_exceptions(builtins))
|
||||
context.update(get_exceptions(defining_module))
|
||||
for exc_name, exc in docstring.raises.items():
|
||||
|
@ -262,8 +257,8 @@ def parse_text(text: str) -> Docstring:
|
|||
else:
|
||||
long_description = None
|
||||
|
||||
params: Dict[str, DocstringParam] = {}
|
||||
raises: Dict[str, DocstringRaises] = {}
|
||||
params: dict[str, DocstringParam] = {}
|
||||
raises: dict[str, DocstringRaises] = {}
|
||||
returns = None
|
||||
for match in re.finditer(r"(^:.*?)(?=^:|\Z)", meta_chunk, flags=re.DOTALL | re.MULTILINE):
|
||||
chunk = match.group(0)
|
||||
|
@ -325,7 +320,7 @@ def has_docstring(typ: SupportsDoc) -> bool:
|
|||
return bool(typ.__doc__)
|
||||
|
||||
|
||||
def get_docstring(typ: SupportsDoc) -> Optional[str]:
|
||||
def get_docstring(typ: SupportsDoc) -> str | None:
|
||||
if typ.__doc__ is None:
|
||||
return None
|
||||
|
||||
|
@ -348,7 +343,7 @@ def check_docstring(typ: SupportsDoc, docstring: Docstring, strict: bool = False
|
|||
check_function_docstring(typ, docstring, strict)
|
||||
|
||||
|
||||
def check_dataclass_docstring(typ: Type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
|
||||
def check_dataclass_docstring(typ: type[DataclassInstance], docstring: Docstring, strict: bool = False) -> None:
|
||||
"""
|
||||
Verifies the doc-string of a data-class type.
|
||||
|
||||
|
|
|
@ -22,34 +22,19 @@ import sys
|
|||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
S = TypeVar("S")
|
||||
T = TypeVar("T")
|
||||
K = TypeVar("K")
|
||||
|
@ -80,28 +65,20 @@ def _is_type_like(data_type: object) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
|
||||
TypeLike = Union[type, types.GenericAlias, typing.ForwardRef, Any]
|
||||
|
||||
def is_type_like(
|
||||
data_type: object,
|
||||
) -> TypeGuard[TypeLike]:
|
||||
"""
|
||||
Checks if the object is a type or type-like object (e.g. generic type).
|
||||
|
||||
:param data_type: The object to validate.
|
||||
:returns: True if the object is a type or type-like object.
|
||||
"""
|
||||
def is_type_like(
|
||||
data_type: object,
|
||||
) -> TypeGuard[TypeLike]:
|
||||
"""
|
||||
Checks if the object is a type or type-like object (e.g. generic type).
|
||||
|
||||
return _is_type_like(data_type)
|
||||
:param data_type: The object to validate.
|
||||
:returns: True if the object is a type or type-like object.
|
||||
"""
|
||||
|
||||
else:
|
||||
TypeLike = object
|
||||
|
||||
def is_type_like(
|
||||
data_type: object,
|
||||
) -> bool:
|
||||
return _is_type_like(data_type)
|
||||
return _is_type_like(data_type)
|
||||
|
||||
|
||||
def evaluate_member_type(typ: Any, cls: type) -> Any:
|
||||
|
@ -129,20 +106,17 @@ def evaluate_type(typ: Any, module: types.ModuleType) -> Any:
|
|||
# evaluate data-class field whose type annotation is a string
|
||||
return eval(typ, module.__dict__, locals())
|
||||
if isinstance(typ, typing.ForwardRef):
|
||||
if sys.version_info >= (3, 9):
|
||||
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
|
||||
else:
|
||||
return typ._evaluate(module.__dict__, locals())
|
||||
return typ._evaluate(module.__dict__, locals(), recursive_guard=frozenset())
|
||||
else:
|
||||
return typ
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DataclassInstance(Protocol):
|
||||
__dataclass_fields__: typing.ClassVar[Dict[str, dataclasses.Field]]
|
||||
__dataclass_fields__: typing.ClassVar[dict[str, dataclasses.Field]]
|
||||
|
||||
|
||||
def is_dataclass_type(typ: Any) -> TypeGuard[Type[DataclassInstance]]:
|
||||
def is_dataclass_type(typ: Any) -> TypeGuard[type[DataclassInstance]]:
|
||||
"True if the argument corresponds to a data class type (but not an instance)."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
|
@ -167,14 +141,14 @@ class DataclassField:
|
|||
self.default = default
|
||||
|
||||
|
||||
def dataclass_fields(cls: Type[DataclassInstance]) -> Iterable[DataclassField]:
|
||||
def dataclass_fields(cls: type[DataclassInstance]) -> Iterable[DataclassField]:
|
||||
"Generates the fields of a data-class resolving forward references."
|
||||
|
||||
for field in dataclasses.fields(cls):
|
||||
yield DataclassField(field.name, evaluate_member_type(field.type, cls), field.default)
|
||||
|
||||
|
||||
def dataclass_field_by_name(cls: Type[DataclassInstance], name: str) -> DataclassField:
|
||||
def dataclass_field_by_name(cls: type[DataclassInstance], name: str) -> DataclassField:
|
||||
"Looks up a field in a data-class by its field name."
|
||||
|
||||
for field in dataclasses.fields(cls):
|
||||
|
@ -190,7 +164,7 @@ def is_named_tuple_instance(obj: Any) -> TypeGuard[NamedTuple]:
|
|||
return is_named_tuple_type(type(obj))
|
||||
|
||||
|
||||
def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
|
||||
def is_named_tuple_type(typ: Any) -> TypeGuard[type[NamedTuple]]:
|
||||
"""
|
||||
True if the argument corresponds to a named tuple type.
|
||||
|
||||
|
@ -217,26 +191,14 @@ def is_named_tuple_type(typ: Any) -> TypeGuard[Type[NamedTuple]]:
|
|||
return all(isinstance(n, str) for n in f)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
def is_type_enum(typ: object) -> TypeGuard[type[enum.Enum]]:
|
||||
"True if the specified type is an enumeration type."
|
||||
|
||||
def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]:
|
||||
"True if the specified type is an enumeration type."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return isinstance(typ, enum.EnumType)
|
||||
|
||||
else:
|
||||
|
||||
def is_type_enum(typ: object) -> TypeGuard[Type[enum.Enum]]:
|
||||
"True if the specified type is an enumeration type."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
|
||||
# use an explicit isinstance(..., type) check to filter out special forms like generics
|
||||
return isinstance(typ, type) and issubclass(typ, enum.Enum)
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return isinstance(typ, enum.EnumType)
|
||||
|
||||
|
||||
def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
|
||||
def enum_value_types(enum_type: type[enum.Enum]) -> list[type]:
|
||||
"""
|
||||
Returns all unique value types of the `enum.Enum` type in definition order.
|
||||
"""
|
||||
|
@ -246,8 +208,8 @@ def enum_value_types(enum_type: Type[enum.Enum]) -> List[type]:
|
|||
|
||||
|
||||
def extend_enum(
|
||||
source: Type[enum.Enum],
|
||||
) -> Callable[[Type[enum.Enum]], Type[enum.Enum]]:
|
||||
source: type[enum.Enum],
|
||||
) -> Callable[[type[enum.Enum]], type[enum.Enum]]:
|
||||
"""
|
||||
Creates a new enumeration type extending the set of values in an existing type.
|
||||
|
||||
|
@ -255,13 +217,13 @@ def extend_enum(
|
|||
:returns: A new enumeration type with the extended set of values.
|
||||
"""
|
||||
|
||||
def wrap(extend: Type[enum.Enum]) -> Type[enum.Enum]:
|
||||
def wrap(extend: type[enum.Enum]) -> type[enum.Enum]:
|
||||
# create new enumeration type combining the values from both types
|
||||
values: Dict[str, Any] = {}
|
||||
values: dict[str, Any] = {}
|
||||
values.update((e.name, e.value) for e in source)
|
||||
values.update((e.name, e.value) for e in extend)
|
||||
# mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
|
||||
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
|
||||
enum_class: type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
|
||||
|
||||
# assign the newly created type to the same module where the extending class is defined
|
||||
enum_class.__module__ = extend.__module__
|
||||
|
@ -273,22 +235,13 @@ def extend_enum(
|
|||
return wrap
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
def _is_union_like(typ: object) -> bool:
|
||||
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
|
||||
|
||||
def _is_union_like(typ: object) -> bool:
|
||||
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
|
||||
|
||||
return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
|
||||
|
||||
else:
|
||||
|
||||
def _is_union_like(typ: object) -> bool:
|
||||
"True if type is a union such as `Union[T1, T2, ...]` or a union type `T1 | T2`."
|
||||
|
||||
return typing.get_origin(typ) is Union
|
||||
return typing.get_origin(typ) is Union or isinstance(typ, types.UnionType)
|
||||
|
||||
|
||||
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Optional[Any]]]:
|
||||
def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[type[Any | None]]:
|
||||
"""
|
||||
True if the type annotation corresponds to an optional type (e.g. `Optional[T]` or `Union[T1,T2,None]`).
|
||||
|
||||
|
@ -309,7 +262,7 @@ def is_type_optional(typ: object, strict: bool = False) -> TypeGuard[Type[Option
|
|||
return False
|
||||
|
||||
|
||||
def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
|
||||
def unwrap_optional_type(typ: type[T | None]) -> type[T]:
|
||||
"""
|
||||
Extracts the inner type of an optional type.
|
||||
|
||||
|
@ -320,7 +273,7 @@ def unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
|
|||
return rewrap_annotated_type(_unwrap_optional_type, typ)
|
||||
|
||||
|
||||
def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
|
||||
def _unwrap_optional_type(typ: type[T | None]) -> type[T]:
|
||||
"Extracts the type qualified as optional (e.g. returns `T` for `Optional[T]`)."
|
||||
|
||||
# Optional[T] is represented internally as Union[T, None]
|
||||
|
@ -342,7 +295,7 @@ def is_type_union(typ: object) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def unwrap_union_types(typ: object) -> Tuple[object, ...]:
|
||||
def unwrap_union_types(typ: object) -> tuple[object, ...]:
|
||||
"""
|
||||
Extracts the inner types of a union type.
|
||||
|
||||
|
@ -354,7 +307,7 @@ def unwrap_union_types(typ: object) -> Tuple[object, ...]:
|
|||
return _unwrap_union_types(typ)
|
||||
|
||||
|
||||
def _unwrap_union_types(typ: object) -> Tuple[object, ...]:
|
||||
def _unwrap_union_types(typ: object) -> tuple[object, ...]:
|
||||
"Extracts the types in a union (e.g. returns a tuple of types `T1` and `T2` for `Union[T1, T2]`)."
|
||||
|
||||
if not _is_union_like(typ):
|
||||
|
@ -385,7 +338,7 @@ def unwrap_literal_value(typ: object) -> Any:
|
|||
return args[0]
|
||||
|
||||
|
||||
def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
|
||||
def unwrap_literal_values(typ: object) -> tuple[Any, ...]:
|
||||
"""
|
||||
Extracts the constant values captured by a literal type.
|
||||
|
||||
|
@ -397,7 +350,7 @@ def unwrap_literal_values(typ: object) -> Tuple[Any, ...]:
|
|||
return typing.get_args(typ)
|
||||
|
||||
|
||||
def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
|
||||
def unwrap_literal_types(typ: object) -> tuple[type, ...]:
|
||||
"""
|
||||
Extracts the types of the constant values captured by a literal type.
|
||||
|
||||
|
@ -408,14 +361,14 @@ def unwrap_literal_types(typ: object) -> Tuple[type, ...]:
|
|||
return tuple(type(t) for t in unwrap_literal_values(typ))
|
||||
|
||||
|
||||
def is_generic_list(typ: object) -> TypeGuard[Type[list]]:
|
||||
def is_generic_list(typ: object) -> TypeGuard[type[list]]:
|
||||
"True if the specified type is a generic list, i.e. `List[T]`."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return typing.get_origin(typ) is list
|
||||
|
||||
|
||||
def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
|
||||
def unwrap_generic_list(typ: type[list[T]]) -> type[T]:
|
||||
"""
|
||||
Extracts the item type of a list type.
|
||||
|
||||
|
@ -426,21 +379,21 @@ def unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
|
|||
return rewrap_annotated_type(_unwrap_generic_list, typ)
|
||||
|
||||
|
||||
def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
|
||||
def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
|
||||
"Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
|
||||
|
||||
(list_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return list_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
|
||||
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
|
||||
"True if the specified type is a generic set, i.e. `Set[T]`."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return typing.get_origin(typ) is set
|
||||
|
||||
|
||||
def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
|
||||
def unwrap_generic_set(typ: type[set[T]]) -> type[T]:
|
||||
"""
|
||||
Extracts the item type of a set type.
|
||||
|
||||
|
@ -451,21 +404,21 @@ def unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
|
|||
return rewrap_annotated_type(_unwrap_generic_set, typ)
|
||||
|
||||
|
||||
def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
|
||||
def _unwrap_generic_set(typ: type[set[T]]) -> type[T]:
|
||||
"Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
|
||||
|
||||
(set_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return set_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
|
||||
def is_generic_dict(typ: object) -> TypeGuard[type[dict]]:
|
||||
"True if the specified type is a generic dictionary, i.e. `Dict[KeyType, ValueType]`."
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return typing.get_origin(typ) is dict
|
||||
|
||||
|
||||
def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
|
||||
def unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
|
||||
"""
|
||||
Extracts the key and value types of a dictionary type as a tuple.
|
||||
|
||||
|
@ -476,7 +429,7 @@ def unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
|
|||
return _unwrap_generic_dict(unwrap_annotated_type(typ))
|
||||
|
||||
|
||||
def _unwrap_generic_dict(typ: Type[Dict[K, V]]) -> Tuple[Type[K], Type[V]]:
|
||||
def _unwrap_generic_dict(typ: type[dict[K, V]]) -> tuple[type[K], type[V]]:
|
||||
"Extracts the key and value types of a dict type (e.g. returns (`K`, `V`) for `Dict[K, V]`)."
|
||||
|
||||
key_type, value_type = typing.get_args(typ)
|
||||
|
@ -489,7 +442,7 @@ def is_type_annotated(typ: TypeLike) -> bool:
|
|||
return getattr(typ, "__metadata__", None) is not None
|
||||
|
||||
|
||||
def get_annotation(data_type: TypeLike, annotation_type: Type[T]) -> Optional[T]:
|
||||
def get_annotation(data_type: TypeLike, annotation_type: type[T]) -> T | None:
|
||||
"""
|
||||
Returns the first annotation on a data type that matches the expected annotation type.
|
||||
|
||||
|
@ -518,7 +471,7 @@ def unwrap_annotated_type(typ: T) -> T:
|
|||
return typ
|
||||
|
||||
|
||||
def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S]) -> Type[T]:
|
||||
def rewrap_annotated_type(transform: Callable[[type[S]], type[T]], typ: type[S]) -> type[T]:
|
||||
"""
|
||||
Un-boxes, transforms and re-boxes an optionally annotated type.
|
||||
|
||||
|
@ -542,7 +495,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
|
|||
return transformed_type
|
||||
|
||||
|
||||
def get_module_classes(module: types.ModuleType) -> List[type]:
|
||||
def get_module_classes(module: types.ModuleType) -> list[type]:
|
||||
"Returns all classes declared directly in a module."
|
||||
|
||||
def is_class_member(member: object) -> TypeGuard[type]:
|
||||
|
@ -551,18 +504,11 @@ def get_module_classes(module: types.ModuleType) -> List[type]:
|
|||
return [class_type for _, class_type in inspect.getmembers(module, is_class_member)]
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
|
||||
def get_resolved_hints(typ: type) -> Dict[str, type]:
|
||||
return typing.get_type_hints(typ, include_extras=True)
|
||||
|
||||
else:
|
||||
|
||||
def get_resolved_hints(typ: type) -> Dict[str, type]:
|
||||
return typing.get_type_hints(typ)
|
||||
def get_resolved_hints(typ: type) -> dict[str, type]:
|
||||
return typing.get_type_hints(typ, include_extras=True)
|
||||
|
||||
|
||||
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
|
||||
def get_class_properties(typ: type) -> Iterable[tuple[str, type | str]]:
|
||||
"Returns all properties of a class."
|
||||
|
||||
if is_dataclass_type(typ):
|
||||
|
@ -572,7 +518,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
|
|||
return resolved_hints.items()
|
||||
|
||||
|
||||
def get_class_property(typ: type, name: str) -> Optional[type | str]:
|
||||
def get_class_property(typ: type, name: str) -> type | str | None:
|
||||
"Looks up the annotated type of a property in a class by its property name."
|
||||
|
||||
for property_name, property_type in get_class_properties(typ):
|
||||
|
@ -586,7 +532,7 @@ class _ROOT:
|
|||
pass
|
||||
|
||||
|
||||
def get_referenced_types(typ: TypeLike, module: Optional[types.ModuleType] = None) -> Set[type]:
|
||||
def get_referenced_types(typ: TypeLike, module: types.ModuleType | None = None) -> set[type]:
|
||||
"""
|
||||
Extracts types directly or indirectly referenced by this type.
|
||||
|
||||
|
@ -610,10 +556,10 @@ class TypeCollector:
|
|||
:param graph: The type dependency graph, linking types to types they depend on.
|
||||
"""
|
||||
|
||||
graph: Dict[type, Set[type]]
|
||||
graph: dict[type, set[type]]
|
||||
|
||||
@property
|
||||
def references(self) -> Set[type]:
|
||||
def references(self) -> set[type]:
|
||||
"Types collected by the type collector."
|
||||
|
||||
dependencies = set()
|
||||
|
@ -638,8 +584,8 @@ class TypeCollector:
|
|||
def run(
|
||||
self,
|
||||
typ: TypeLike,
|
||||
cls: Type[DataclassInstance],
|
||||
module: Optional[types.ModuleType],
|
||||
cls: type[DataclassInstance],
|
||||
module: types.ModuleType | None,
|
||||
) -> None:
|
||||
"""
|
||||
Extracts types indirectly referenced by this type.
|
||||
|
@ -702,26 +648,17 @@ class TypeCollector:
|
|||
for field in dataclass_fields(typ):
|
||||
self.run(field.type, typ, context)
|
||||
else:
|
||||
for field_name, field_type in get_resolved_hints(typ).items():
|
||||
for _field_name, field_type in get_resolved_hints(typ).items():
|
||||
self.run(field_type, typ, context)
|
||||
return
|
||||
|
||||
raise TypeError(f"expected: type-like; got: {typ}")
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
|
||||
"Extracts the signature of a function."
|
||||
|
||||
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
|
||||
"Extracts the signature of a function."
|
||||
|
||||
return inspect.signature(fn, eval_str=True)
|
||||
|
||||
else:
|
||||
|
||||
def get_signature(fn: Callable[..., Any]) -> inspect.Signature:
|
||||
"Extracts the signature of a function."
|
||||
|
||||
return inspect.signature(fn)
|
||||
return inspect.signature(fn, eval_str=True)
|
||||
|
||||
|
||||
def is_reserved_property(name: str) -> bool:
|
||||
|
@ -756,51 +693,20 @@ def create_module(name: str) -> types.ModuleType:
|
|||
return module
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
def create_data_type(class_name: str, fields: list[tuple[str, type]]) -> type:
|
||||
"""
|
||||
Creates a new data-class type dynamically.
|
||||
|
||||
def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
|
||||
"""
|
||||
Creates a new data-class type dynamically.
|
||||
:param class_name: The name of new data-class type.
|
||||
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
|
||||
:returns: The newly created data-class type.
|
||||
"""
|
||||
|
||||
:param class_name: The name of new data-class type.
|
||||
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
|
||||
:returns: The newly created data-class type.
|
||||
"""
|
||||
|
||||
# has the `slots` parameter
|
||||
return dataclasses.make_dataclass(class_name, fields, slots=True)
|
||||
|
||||
else:
|
||||
|
||||
def create_data_type(class_name: str, fields: List[Tuple[str, type]]) -> type:
|
||||
"""
|
||||
Creates a new data-class type dynamically.
|
||||
|
||||
:param class_name: The name of new data-class type.
|
||||
:param fields: A list of fields (and their type) that the new data-class type is expected to have.
|
||||
:returns: The newly created data-class type.
|
||||
"""
|
||||
|
||||
cls = dataclasses.make_dataclass(class_name, fields)
|
||||
|
||||
cls_dict = dict(cls.__dict__)
|
||||
field_names = tuple(field.name for field in dataclasses.fields(cls))
|
||||
|
||||
cls_dict["__slots__"] = field_names
|
||||
|
||||
for field_name in field_names:
|
||||
cls_dict.pop(field_name, None)
|
||||
cls_dict.pop("__dict__", None)
|
||||
|
||||
qualname = getattr(cls, "__qualname__", None)
|
||||
cls = type(cls)(cls.__name__, (), cls_dict)
|
||||
if qualname is not None:
|
||||
cls.__qualname__ = qualname
|
||||
|
||||
return cls
|
||||
# has the `slots` parameter
|
||||
return dataclasses.make_dataclass(class_name, fields, slots=True)
|
||||
|
||||
|
||||
def create_object(typ: Type[T]) -> T:
|
||||
def create_object(typ: type[T]) -> T:
|
||||
"Creates an instance of a type."
|
||||
|
||||
if issubclass(typ, Exception):
|
||||
|
@ -811,11 +717,7 @@ def create_object(typ: Type[T]) -> T:
|
|||
return object.__new__(typ)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
TypeOrGeneric = Union[type, types.GenericAlias]
|
||||
|
||||
else:
|
||||
TypeOrGeneric = object
|
||||
TypeOrGeneric = Union[type, types.GenericAlias]
|
||||
|
||||
|
||||
def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
|
||||
|
@ -885,7 +787,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
|
|||
|
||||
|
||||
class RecursiveChecker:
|
||||
_pred: Optional[Callable[[type, Any], bool]]
|
||||
_pred: Callable[[type, Any], bool] | None
|
||||
|
||||
def __init__(self, pred: Callable[[type, Any], bool]) -> None:
|
||||
"""
|
||||
|
@ -997,9 +899,9 @@ def check_recursive(
|
|||
obj: object,
|
||||
/,
|
||||
*,
|
||||
pred: Optional[Callable[[type, Any], bool]] = None,
|
||||
type_pred: Optional[Callable[[type], bool]] = None,
|
||||
value_pred: Optional[Callable[[Any], bool]] = None,
|
||||
pred: Callable[[type, Any], bool] | None = None,
|
||||
type_pred: Callable[[type], bool] | None = None,
|
||||
value_pred: Callable[[Any], bool] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Checks if a predicate applies to all nested member properties of an object recursively.
|
||||
|
@ -1015,7 +917,7 @@ def check_recursive(
|
|||
if pred is not None:
|
||||
raise TypeError("filter predicate not permitted when type and value predicates are present")
|
||||
|
||||
type_p: Callable[[Type[T]], bool] = type_pred
|
||||
type_p: Callable[[type[T]], bool] = type_pred
|
||||
value_p: Callable[[T], bool] = value_pred
|
||||
pred = lambda typ, obj: not type_p(typ) or value_p(obj) # noqa: E731
|
||||
|
||||
|
|
|
@ -11,13 +11,12 @@ Type-safe data interchange for Python data classes.
|
|||
"""
|
||||
|
||||
import keyword
|
||||
from typing import Optional
|
||||
|
||||
from .auxiliary import Alias
|
||||
from .inspection import get_annotation
|
||||
|
||||
|
||||
def python_field_to_json_property(python_id: str, python_type: Optional[object] = None) -> str:
|
||||
def python_field_to_json_property(python_id: str, python_type: object | None = None) -> str:
|
||||
"""
|
||||
Map a Python field identifier to a JSON property name.
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ Type-safe data interchange for Python data classes.
|
|||
"""
|
||||
|
||||
import typing
|
||||
from typing import Any, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from .auxiliary import _auxiliary_types
|
||||
from .inspection import (
|
||||
|
@ -39,7 +39,7 @@ class TypeFormatter:
|
|||
def __init__(self, use_union_operator: bool = False) -> None:
|
||||
self.use_union_operator = use_union_operator
|
||||
|
||||
def union_to_str(self, data_type_args: Tuple[TypeLike, ...]) -> str:
|
||||
def union_to_str(self, data_type_args: tuple[TypeLike, ...]) -> str:
|
||||
if self.use_union_operator:
|
||||
return " | ".join(self.python_type_to_str(t) for t in data_type_args)
|
||||
else:
|
||||
|
@ -100,7 +100,7 @@ class TypeFormatter:
|
|||
metadata = getattr(data_type, "__metadata__", None)
|
||||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
metatuple: Tuple[Any, ...] = metadata
|
||||
metatuple: tuple[Any, ...] = metadata
|
||||
arg = typing.get_args(data_type)[0]
|
||||
|
||||
# check for auxiliary types with user-defined annotations
|
||||
|
@ -110,7 +110,7 @@ class TypeFormatter:
|
|||
if arg is not auxiliary_arg:
|
||||
continue
|
||||
|
||||
auxiliary_metatuple: Optional[Tuple[Any, ...]] = getattr(auxiliary_type, "__metadata__", None)
|
||||
auxiliary_metatuple: tuple[Any, ...] | None = getattr(auxiliary_type, "__metadata__", None)
|
||||
if auxiliary_metatuple is None:
|
||||
continue
|
||||
|
||||
|
|
|
@ -21,24 +21,19 @@ import json
|
|||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import jsonschema
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from . import docstring
|
||||
from .auxiliary import (
|
||||
|
@ -71,7 +66,7 @@ OBJECT_ENUM_EXPANSION_LIMIT = 4
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_class_docstrings(data_type: type) -> tuple[str | None, str | None]:
|
||||
docstr = docstring.parse_type(data_type)
|
||||
|
||||
# check if class has a doc-string other than the auto-generated string assigned by @dataclass
|
||||
|
@ -82,8 +77,8 @@ def get_class_docstrings(data_type: type) -> Tuple[Optional[str], Optional[str]]
|
|||
|
||||
|
||||
def get_class_property_docstrings(
|
||||
data_type: type, transform_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
) -> Dict[str, str]:
|
||||
data_type: type, transform_fun: Callable[[type, str, str], str] | None = None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Extracts the documentation strings associated with the properties of a composite type.
|
||||
|
||||
|
@ -120,7 +115,7 @@ def docstring_to_schema(data_type: type) -> Schema:
|
|||
return schema
|
||||
|
||||
|
||||
def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
|
||||
def id_from_ref(data_type: typing.ForwardRef | str | type) -> str:
|
||||
"Extracts the name of a possibly forward-referenced type."
|
||||
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
|
@ -132,7 +127,7 @@ def id_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> str:
|
|||
return data_type.__name__
|
||||
|
||||
|
||||
def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str, type]:
|
||||
def type_from_ref(data_type: typing.ForwardRef | str | type) -> tuple[str, type]:
|
||||
"Creates a type from a forward reference."
|
||||
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
|
@ -148,16 +143,16 @@ def type_from_ref(data_type: Union[typing.ForwardRef, str, type]) -> Tuple[str,
|
|||
|
||||
@dataclasses.dataclass
|
||||
class TypeCatalogEntry:
|
||||
schema: Optional[Schema]
|
||||
schema: Schema | None
|
||||
identifier: str
|
||||
examples: Optional[JsonType] = None
|
||||
examples: JsonType | None = None
|
||||
|
||||
|
||||
class TypeCatalog:
|
||||
"Maintains an association of well-known Python types to their JSON schema."
|
||||
|
||||
_by_type: Dict[TypeLike, TypeCatalogEntry]
|
||||
_by_name: Dict[str, TypeCatalogEntry]
|
||||
_by_type: dict[TypeLike, TypeCatalogEntry]
|
||||
_by_name: dict[str, TypeCatalogEntry]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._by_type = {}
|
||||
|
@ -174,9 +169,9 @@ class TypeCatalog:
|
|||
def add(
|
||||
self,
|
||||
data_type: TypeLike,
|
||||
schema: Optional[Schema],
|
||||
schema: Schema | None,
|
||||
identifier: str,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
examples: list[JsonType] | None = None,
|
||||
) -> None:
|
||||
if isinstance(data_type, typing.ForwardRef):
|
||||
raise TypeError("forward references cannot be used to register a type")
|
||||
|
@ -202,17 +197,17 @@ class SchemaOptions:
|
|||
definitions_path: str = "#/definitions/"
|
||||
use_descriptions: bool = True
|
||||
use_examples: bool = True
|
||||
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
||||
property_description_fun: Callable[[type, str, str], str] | None = None
|
||||
|
||||
|
||||
class JsonSchemaGenerator:
|
||||
"Creates a JSON schema with user-defined type definitions."
|
||||
|
||||
type_catalog: ClassVar[TypeCatalog] = TypeCatalog()
|
||||
types_used: Dict[str, TypeLike]
|
||||
types_used: dict[str, TypeLike]
|
||||
options: SchemaOptions
|
||||
|
||||
def __init__(self, options: Optional[SchemaOptions] = None):
|
||||
def __init__(self, options: SchemaOptions | None = None):
|
||||
if options is None:
|
||||
self.options = SchemaOptions()
|
||||
else:
|
||||
|
@ -244,13 +239,13 @@ class JsonSchemaGenerator:
|
|||
def _(self, arg: MaxLength) -> Schema:
|
||||
return {"maxLength": arg.value}
|
||||
|
||||
def _with_metadata(self, type_schema: Schema, metadata: Optional[Tuple[Any, ...]]) -> Schema:
|
||||
def _with_metadata(self, type_schema: Schema, metadata: tuple[Any, ...] | None) -> Schema:
|
||||
if metadata:
|
||||
for m in metadata:
|
||||
type_schema.update(self._metadata_to_schema(m))
|
||||
return type_schema
|
||||
|
||||
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: Optional[dict] = None) -> Optional[Schema]:
|
||||
def _simple_type_to_schema(self, typ: TypeLike, json_schema_extra: dict | None = None) -> Schema | None:
|
||||
"""
|
||||
Returns the JSON schema associated with a simple, unrestricted type.
|
||||
|
||||
|
@ -314,7 +309,7 @@ class JsonSchemaGenerator:
|
|||
self,
|
||||
data_type: TypeLike,
|
||||
force_expand: bool = False,
|
||||
json_schema_extra: Optional[dict] = None,
|
||||
json_schema_extra: dict | None = None,
|
||||
) -> Schema:
|
||||
common_info = {}
|
||||
if json_schema_extra and "deprecated" in json_schema_extra:
|
||||
|
@ -325,7 +320,7 @@ class JsonSchemaGenerator:
|
|||
self,
|
||||
data_type: TypeLike,
|
||||
force_expand: bool = False,
|
||||
json_schema_extra: Optional[dict] = None,
|
||||
json_schema_extra: dict | None = None,
|
||||
) -> Schema:
|
||||
"""
|
||||
Returns the JSON schema associated with a type.
|
||||
|
@ -381,7 +376,7 @@ class JsonSchemaGenerator:
|
|||
return {"$ref": f"{self.options.definitions_path}{identifier}"}
|
||||
|
||||
if is_type_enum(typ):
|
||||
enum_type: Type[enum.Enum] = typ
|
||||
enum_type: type[enum.Enum] = typ
|
||||
value_types = enum_value_types(enum_type)
|
||||
if len(value_types) != 1:
|
||||
raise ValueError(
|
||||
|
@ -496,8 +491,8 @@ class JsonSchemaGenerator:
|
|||
members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
|
||||
|
||||
property_docstrings = get_class_property_docstrings(typ, self.options.property_description_fun)
|
||||
properties: Dict[str, Schema] = {}
|
||||
required: List[str] = []
|
||||
properties: dict[str, Schema] = {}
|
||||
required: list[str] = []
|
||||
for property_name, property_type in get_class_properties(typ):
|
||||
# rename property if an alias name is specified
|
||||
alias = get_annotation(property_type, Alias)
|
||||
|
@ -530,16 +525,7 @@ class JsonSchemaGenerator:
|
|||
# check if value can be directly represented in JSON
|
||||
if isinstance(
|
||||
def_value,
|
||||
(
|
||||
bool,
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
enum.Enum,
|
||||
datetime.datetime,
|
||||
datetime.date,
|
||||
datetime.time,
|
||||
),
|
||||
bool | int | float | str | enum.Enum | datetime.datetime | datetime.date | datetime.time,
|
||||
):
|
||||
property_def["default"] = object_to_json(def_value)
|
||||
|
||||
|
@ -587,7 +573,7 @@ class JsonSchemaGenerator:
|
|||
|
||||
return type_schema
|
||||
|
||||
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> Tuple[Schema, Dict[str, Schema]]:
|
||||
def classdef_to_schema(self, data_type: TypeLike, force_expand: bool = False) -> tuple[Schema, dict[str, Schema]]:
|
||||
"""
|
||||
Returns the JSON schema associated with a type and any nested types.
|
||||
|
||||
|
@ -604,7 +590,7 @@ class JsonSchemaGenerator:
|
|||
try:
|
||||
type_schema = self.type_to_schema(data_type, force_expand=force_expand)
|
||||
|
||||
types_defined: Dict[str, Schema] = {}
|
||||
types_defined: dict[str, Schema] = {}
|
||||
while len(self.types_used) > len(types_defined):
|
||||
# make a snapshot copy; original collection is going to be modified
|
||||
types_undefined = {
|
||||
|
@ -635,7 +621,7 @@ class Validator(enum.Enum):
|
|||
|
||||
def classdef_to_schema(
|
||||
data_type: TypeLike,
|
||||
options: Optional[SchemaOptions] = None,
|
||||
options: SchemaOptions | None = None,
|
||||
validator: Validator = Validator.Latest,
|
||||
) -> Schema:
|
||||
"""
|
||||
|
@ -689,7 +675,7 @@ def print_schema(data_type: type) -> None:
|
|||
print(json.dumps(s, indent=4))
|
||||
|
||||
|
||||
def get_schema_identifier(data_type: type) -> Optional[str]:
|
||||
def get_schema_identifier(data_type: type) -> str | None:
|
||||
if data_type in JsonSchemaGenerator.type_catalog:
|
||||
return JsonSchemaGenerator.type_catalog.get(data_type).identifier
|
||||
else:
|
||||
|
@ -698,9 +684,9 @@ def get_schema_identifier(data_type: type) -> Optional[str]:
|
|||
|
||||
def register_schema(
|
||||
data_type: T,
|
||||
schema: Optional[Schema] = None,
|
||||
name: Optional[str] = None,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
schema: Schema | None = None,
|
||||
name: str | None = None,
|
||||
examples: list[JsonType] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Associates a type with a JSON schema definition.
|
||||
|
@ -721,22 +707,22 @@ def register_schema(
|
|||
|
||||
|
||||
@overload
|
||||
def json_schema_type(cls: Type[T], /) -> Type[T]: ...
|
||||
def json_schema_type(cls: type[T], /) -> type[T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def json_schema_type(cls: None, *, schema: Optional[Schema] = None) -> Callable[[Type[T]], Type[T]]: ...
|
||||
def json_schema_type(cls: None, *, schema: Schema | None = None) -> Callable[[type[T]], type[T]]: ...
|
||||
|
||||
|
||||
def json_schema_type(
|
||||
cls: Optional[Type[T]] = None,
|
||||
cls: type[T] | None = None,
|
||||
*,
|
||||
schema: Optional[Schema] = None,
|
||||
examples: Optional[List[JsonType]] = None,
|
||||
) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
|
||||
schema: Schema | None = None,
|
||||
examples: list[JsonType] | None = None,
|
||||
) -> type[T] | Callable[[type[T]], type[T]]:
|
||||
"""Decorator to add user-defined schema definition to a class."""
|
||||
|
||||
def wrap(cls: Type[T]) -> Type[T]:
|
||||
def wrap(cls: type[T]) -> type[T]:
|
||||
return register_schema(cls, schema, examples=examples)
|
||||
|
||||
# see if decorator is used as @json_schema_type or @json_schema_type()
|
||||
|
|
|
@ -14,7 +14,7 @@ import inspect
|
|||
import json
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any, Optional, TextIO, TypeVar
|
||||
from typing import Any, TextIO, TypeVar
|
||||
|
||||
from .core import JsonType
|
||||
from .deserializer import create_deserializer
|
||||
|
@ -42,7 +42,7 @@ def object_to_json(obj: Any) -> JsonType:
|
|||
return generator.generate(obj)
|
||||
|
||||
|
||||
def json_to_object(typ: TypeLike, data: JsonType, *, context: Optional[ModuleType] = None) -> object:
|
||||
def json_to_object(typ: TypeLike, data: JsonType, *, context: ModuleType | None = None) -> object:
|
||||
"""
|
||||
Creates an object from a representation that has been de-serialized from JSON.
|
||||
|
||||
|
|
|
@ -20,19 +20,13 @@ import ipaddress
|
|||
import sys
|
||||
import typing
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from types import FunctionType, MethodType, ModuleType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
@ -133,7 +127,7 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
|
|||
|
||||
|
||||
class EnumSerializer(Serializer[enum.Enum]):
|
||||
def generate(self, obj: enum.Enum) -> Union[int, str]:
|
||||
def generate(self, obj: enum.Enum) -> int | str:
|
||||
value = obj.value
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
|
@ -141,12 +135,12 @@ class EnumSerializer(Serializer[enum.Enum]):
|
|||
|
||||
|
||||
class UntypedListSerializer(Serializer[list]):
|
||||
def generate(self, obj: list) -> List[JsonType]:
|
||||
def generate(self, obj: list) -> list[JsonType]:
|
||||
return [object_to_json(item) for item in obj]
|
||||
|
||||
|
||||
class UntypedDictSerializer(Serializer[dict]):
|
||||
def generate(self, obj: dict) -> Dict[str, JsonType]:
|
||||
def generate(self, obj: dict) -> dict[str, JsonType]:
|
||||
if obj and isinstance(next(iter(obj.keys())), enum.Enum):
|
||||
iterator = ((key.value, object_to_json(value)) for key, value in obj.items())
|
||||
else:
|
||||
|
@ -155,41 +149,41 @@ class UntypedDictSerializer(Serializer[dict]):
|
|||
|
||||
|
||||
class UntypedSetSerializer(Serializer[set]):
|
||||
def generate(self, obj: set) -> List[JsonType]:
|
||||
def generate(self, obj: set) -> list[JsonType]:
|
||||
return [object_to_json(item) for item in obj]
|
||||
|
||||
|
||||
class UntypedTupleSerializer(Serializer[tuple]):
|
||||
def generate(self, obj: tuple) -> List[JsonType]:
|
||||
def generate(self, obj: tuple) -> list[JsonType]:
|
||||
return [object_to_json(item) for item in obj]
|
||||
|
||||
|
||||
class TypedCollectionSerializer(Serializer, Generic[T]):
|
||||
generator: Serializer[T]
|
||||
|
||||
def __init__(self, item_type: Type[T], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, item_type: type[T], context: ModuleType | None) -> None:
|
||||
self.generator = _get_serializer(item_type, context)
|
||||
|
||||
|
||||
class TypedListSerializer(TypedCollectionSerializer[T]):
|
||||
def generate(self, obj: List[T]) -> List[JsonType]:
|
||||
def generate(self, obj: list[T]) -> list[JsonType]:
|
||||
return [self.generator.generate(item) for item in obj]
|
||||
|
||||
|
||||
class TypedStringDictSerializer(TypedCollectionSerializer[T]):
|
||||
def __init__(self, value_type: Type[T], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, value_type: type[T], context: ModuleType | None) -> None:
|
||||
super().__init__(value_type, context)
|
||||
|
||||
def generate(self, obj: Dict[str, T]) -> Dict[str, JsonType]:
|
||||
def generate(self, obj: dict[str, T]) -> dict[str, JsonType]:
|
||||
return {key: self.generator.generate(value) for key, value in obj.items()}
|
||||
|
||||
|
||||
class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
|
||||
def __init__(
|
||||
self,
|
||||
key_type: Type[enum.Enum],
|
||||
value_type: Type[T],
|
||||
context: Optional[ModuleType],
|
||||
key_type: type[enum.Enum],
|
||||
value_type: type[T],
|
||||
context: ModuleType | None,
|
||||
) -> None:
|
||||
super().__init__(value_type, context)
|
||||
|
||||
|
@ -203,22 +197,22 @@ class TypedEnumDictSerializer(TypedCollectionSerializer[T]):
|
|||
if value_type is not str:
|
||||
raise JsonTypeError("invalid enumeration key type, expected `enum.Enum` with string values")
|
||||
|
||||
def generate(self, obj: Dict[enum.Enum, T]) -> Dict[str, JsonType]:
|
||||
def generate(self, obj: dict[enum.Enum, T]) -> dict[str, JsonType]:
|
||||
return {key.value: self.generator.generate(value) for key, value in obj.items()}
|
||||
|
||||
|
||||
class TypedSetSerializer(TypedCollectionSerializer[T]):
|
||||
def generate(self, obj: Set[T]) -> JsonType:
|
||||
def generate(self, obj: set[T]) -> JsonType:
|
||||
return [self.generator.generate(item) for item in obj]
|
||||
|
||||
|
||||
class TypedTupleSerializer(Serializer[tuple]):
|
||||
item_generators: Tuple[Serializer, ...]
|
||||
item_generators: tuple[Serializer, ...]
|
||||
|
||||
def __init__(self, item_types: Tuple[type, ...], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, item_types: tuple[type, ...], context: ModuleType | None) -> None:
|
||||
self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)
|
||||
|
||||
def generate(self, obj: tuple) -> List[JsonType]:
|
||||
def generate(self, obj: tuple) -> list[JsonType]:
|
||||
return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]
|
||||
|
||||
|
||||
|
@ -250,16 +244,16 @@ class FieldSerializer(Generic[T]):
|
|||
self.property_name = property_name
|
||||
self.generator = generator
|
||||
|
||||
def generate_field(self, obj: object, object_dict: Dict[str, JsonType]) -> None:
|
||||
def generate_field(self, obj: object, object_dict: dict[str, JsonType]) -> None:
|
||||
value = getattr(obj, self.field_name)
|
||||
if value is not None:
|
||||
object_dict[self.property_name] = self.generator.generate(value)
|
||||
|
||||
|
||||
class TypedClassSerializer(Serializer[T]):
|
||||
property_generators: List[FieldSerializer]
|
||||
property_generators: list[FieldSerializer]
|
||||
|
||||
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
|
||||
self.property_generators = [
|
||||
FieldSerializer(
|
||||
field_name,
|
||||
|
@ -269,8 +263,8 @@ class TypedClassSerializer(Serializer[T]):
|
|||
for field_name, field_type in get_class_properties(class_type)
|
||||
]
|
||||
|
||||
def generate(self, obj: T) -> Dict[str, JsonType]:
|
||||
object_dict: Dict[str, JsonType] = {}
|
||||
def generate(self, obj: T) -> dict[str, JsonType]:
|
||||
object_dict: dict[str, JsonType] = {}
|
||||
for property_generator in self.property_generators:
|
||||
property_generator.generate_field(obj, object_dict)
|
||||
|
||||
|
@ -278,12 +272,12 @@ class TypedClassSerializer(Serializer[T]):
|
|||
|
||||
|
||||
class TypedNamedTupleSerializer(TypedClassSerializer[NamedTuple]):
|
||||
def __init__(self, class_type: Type[NamedTuple], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, class_type: type[NamedTuple], context: ModuleType | None) -> None:
|
||||
super().__init__(class_type, context)
|
||||
|
||||
|
||||
class DataclassSerializer(TypedClassSerializer[T]):
|
||||
def __init__(self, class_type: Type[T], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, class_type: type[T], context: ModuleType | None) -> None:
|
||||
super().__init__(class_type, context)
|
||||
|
||||
|
||||
|
@ -295,7 +289,7 @@ class UnionSerializer(Serializer):
|
|||
class LiteralSerializer(Serializer):
|
||||
generator: Serializer
|
||||
|
||||
def __init__(self, values: Tuple[Any, ...], context: Optional[ModuleType]) -> None:
|
||||
def __init__(self, values: tuple[Any, ...], context: ModuleType | None) -> None:
|
||||
literal_type_tuple = tuple(type(value) for value in values)
|
||||
literal_type_set = set(literal_type_tuple)
|
||||
if len(literal_type_set) != 1:
|
||||
|
@ -312,12 +306,12 @@ class LiteralSerializer(Serializer):
|
|||
|
||||
|
||||
class UntypedNamedTupleSerializer(Serializer):
|
||||
fields: Dict[str, str]
|
||||
fields: dict[str, str]
|
||||
|
||||
def __init__(self, class_type: Type[NamedTuple]) -> None:
|
||||
def __init__(self, class_type: type[NamedTuple]) -> None:
|
||||
# named tuples are also instances of tuple
|
||||
self.fields = {}
|
||||
field_names: Tuple[str, ...] = class_type._fields
|
||||
field_names: tuple[str, ...] = class_type._fields
|
||||
for field_name in field_names:
|
||||
self.fields[field_name] = python_field_to_json_property(field_name)
|
||||
|
||||
|
@ -351,7 +345,7 @@ class UntypedClassSerializer(Serializer):
|
|||
return object_dict
|
||||
|
||||
|
||||
def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Serializer:
|
||||
def create_serializer(typ: TypeLike, context: ModuleType | None = None) -> Serializer:
|
||||
"""
|
||||
Creates a serializer engine to produce an object that can be directly converted into a JSON string.
|
||||
|
||||
|
@ -376,8 +370,8 @@ def create_serializer(typ: TypeLike, context: Optional[ModuleType] = None) -> Se
|
|||
return _get_serializer(typ, context)
|
||||
|
||||
|
||||
def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
|
||||
if isinstance(typ, (str, typing.ForwardRef)):
|
||||
def _get_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
|
||||
if isinstance(typ, str | typing.ForwardRef):
|
||||
if context is None:
|
||||
raise TypeError(f"missing context for evaluating type: {typ}")
|
||||
|
||||
|
@ -390,13 +384,13 @@ def _get_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
|
|||
return _create_serializer(typ, context)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _fetch_serializer(typ: type) -> Serializer:
|
||||
context = sys.modules[typ.__module__]
|
||||
return _create_serializer(typ, context)
|
||||
|
||||
|
||||
def _create_serializer(typ: TypeLike, context: Optional[ModuleType]) -> Serializer:
|
||||
def _create_serializer(typ: TypeLike, context: ModuleType | None) -> Serializer:
|
||||
# check for well-known types
|
||||
if typ is type(None):
|
||||
return NoneSerializer()
|
||||
|
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Tuple, Type, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SlotsMeta(type):
|
||||
def __new__(cls: Type[T], name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> T:
|
||||
def __new__(cls: type[T], name: str, bases: tuple[type, ...], ns: dict[str, Any]) -> T:
|
||||
# caller may have already provided slots, in which case just retain them and keep going
|
||||
slots: Tuple[str, ...] = ns.get("__slots__", ())
|
||||
slots: tuple[str, ...] = ns.get("__slots__", ())
|
||||
|
||||
# add fields with type annotations to slots
|
||||
annotations: Dict[str, Any] = ns.get("__annotations__", {})
|
||||
annotations: dict[str, Any] = ns.get("__annotations__", {})
|
||||
members = tuple(member for member in annotations.keys() if member not in slots)
|
||||
|
||||
# assign slots
|
||||
|
|
|
@ -10,14 +10,15 @@ Type-safe data interchange for Python data classes.
|
|||
:see: https://github.com/hunyadi/strong_typing
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set, TypeVar
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TypeVar
|
||||
|
||||
from .inspection import TypeCollector
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
|
||||
def topological_sort(graph: dict[T, set[T]]) -> list[T]:
|
||||
"""
|
||||
Performs a topological sort of a graph.
|
||||
|
||||
|
@ -29,9 +30,9 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
|
|||
"""
|
||||
|
||||
# empty list that will contain the sorted nodes (in reverse order)
|
||||
ordered: List[T] = []
|
||||
ordered: list[T] = []
|
||||
|
||||
seen: Dict[T, bool] = {}
|
||||
seen: dict[T, bool] = {}
|
||||
|
||||
def _visit(n: T) -> None:
|
||||
status = seen.get(n)
|
||||
|
@ -57,8 +58,8 @@ def topological_sort(graph: Dict[T, Set[T]]) -> List[T]:
|
|||
|
||||
def type_topological_sort(
|
||||
types: Iterable[type],
|
||||
dependency_fn: Optional[Callable[[type], Iterable[type]]] = None,
|
||||
) -> List[type]:
|
||||
dependency_fn: Callable[[type], Iterable[type]] | None = None,
|
||||
) -> list[type]:
|
||||
"""
|
||||
Performs a topological sort of a list of types.
|
||||
|
||||
|
@ -78,7 +79,7 @@ def type_topological_sort(
|
|||
graph = collector.graph
|
||||
|
||||
if dependency_fn:
|
||||
new_types: Set[type] = set()
|
||||
new_types: set[type] = set()
|
||||
for source_type, references in graph.items():
|
||||
dependent_types = dependency_fn(source_type)
|
||||
references.update(dependent_types)
|
||||
|
|
7
llama_stack/templates/ollama/__init__.py
Normal file
7
llama_stack/templates/ollama/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .ollama import get_distribution_template # noqa: F401
|
39
llama_stack/templates/ollama/build.yaml
Normal file
39
llama_stack/templates/ollama/build.yaml
Normal file
|
@ -0,0 +1,39 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Use (an external) Ollama server for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- remote::ollama
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
files:
|
||||
- inline::localfs
|
||||
post_training:
|
||||
- inline::huggingface
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
168
llama_stack/templates/ollama/doc_template.md
Normal file
168
llama_stack/templates/ollama/doc_template.md
Normal file
|
@ -0,0 +1,168 @@
|
|||
# Ollama Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Ollama Server
|
||||
|
||||
This distribution requires an external Ollama server to be running. You can install and run Ollama by following these steps:
|
||||
|
||||
1. **Install Ollama**: Download and install Ollama from [https://ollama.ai/](https://ollama.ai/)
|
||||
|
||||
2. **Start the Ollama server**:
|
||||
```bash
|
||||
ollama serve
|
||||
```
|
||||
By default, Ollama serves on `http://127.0.0.1:11434`
|
||||
|
||||
3. **Pull the required models**:
|
||||
```bash
|
||||
# Pull the inference model
|
||||
ollama pull meta-llama/Llama-3.2-3B-Instruct
|
||||
|
||||
# Pull the embedding model
|
||||
ollama pull all-minilm:latest
|
||||
|
||||
# (Optional) Pull the safety model for run-with-safety.yaml
|
||||
ollama pull meta-llama/Llama-Guard-3-1B
|
||||
```
|
||||
|
||||
## Supported Services
|
||||
|
||||
### Inference: Ollama
|
||||
Uses an external Ollama server for running LLM inference. The server should be accessible at the URL specified in the `OLLAMA_URL` environment variable.
|
||||
|
||||
### Vector IO: FAISS
|
||||
Provides vector storage capabilities using FAISS for embeddings and similarity search operations.
|
||||
|
||||
### Safety: Llama Guard (Optional)
|
||||
When using the `run-with-safety.yaml` configuration, provides safety checks using Llama Guard models running on the Ollama server.
|
||||
|
||||
### Agents: Meta Reference
|
||||
Provides agent execution capabilities using the meta-reference implementation.
|
||||
|
||||
### Post-Training: Hugging Face
|
||||
Supports model fine-tuning using Hugging Face integration.
|
||||
|
||||
### Tool Runtime
|
||||
Supports various external tools including:
|
||||
- Brave Search
|
||||
- Tavily Search
|
||||
- RAG Runtime
|
||||
- Model Context Protocol
|
||||
- Wolfram Alpha
|
||||
|
||||
## Running Llama Stack with Ollama
|
||||
|
||||
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=8321
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
llama stack build --template ollama --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
llama stack build --template ollama --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Running with Safety
|
||||
|
||||
To enable safety checks, use the `run-with-safety.yaml` configuration:
|
||||
|
||||
```bash
|
||||
llama stack run ./run-with-safety.yaml \
|
||||
--port 8321 \
|
||||
--env OLLAMA_URL=$OLLAMA_URL \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
Once your Llama Stack server is running with Ollama, you can interact with it using the Llama Stack client:
|
||||
|
||||
```python
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
client = LlamaStackClient(base_url="http://localhost:8321")
|
||||
|
||||
# Run inference
|
||||
response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.2-3B-Instruct",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
)
|
||||
print(response.completion_message.content)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Connection refused errors**: Ensure your Ollama server is running and accessible at the configured URL.
|
||||
|
||||
2. **Model not found errors**: Make sure you've pulled the required models using `ollama pull <model-name>`.
|
||||
|
||||
3. **Performance issues**: Consider using more powerful models or adjusting the Ollama server configuration for better performance.
|
||||
|
||||
### Logs
|
||||
|
||||
Check the Ollama server logs for any issues:
|
||||
```bash
|
||||
# Ollama logs are typically available in:
|
||||
# - macOS: ~/Library/Logs/Ollama/
|
||||
# - Linux: ~/.ollama/logs/
|
||||
```
|
180
llama_stack/templates/ollama/ollama.py
Normal file
180
llama_stack/templates/ollama/ollama.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::ollama"],
|
||||
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"files": ["inline::localfs"],
|
||||
"post_training": ["inline::huggingface"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
"remote::wolfram-alpha",
|
||||
],
|
||||
}
|
||||
name = "ollama"
|
||||
inference_provider = Provider(
|
||||
provider_id="ollama",
|
||||
provider_type="remote::ollama",
|
||||
config=OllamaImplConfig.sample_run_config(),
|
||||
)
|
||||
vector_io_provider_faiss = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(
|
||||
f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/distributions/{name}"
|
||||
),
|
||||
)
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(
|
||||
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
|
||||
),
|
||||
)
|
||||
posttraining_provider = Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(
|
||||
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/distributions/{name}"
|
||||
),
|
||||
)
|
||||
inference_model = ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="ollama",
|
||||
)
|
||||
safety_model = ModelInput(
|
||||
model_id="${env.SAFETY_MODEL}",
|
||||
provider_id="ollama",
|
||||
)
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="ollama",
|
||||
provider_model_id="all-minilm:latest",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::wolfram_alpha",
|
||||
provider_id="wolfram-alpha",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Use (an external) Ollama server for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
"files": [files_provider],
|
||||
"post_training": [posttraining_provider],
|
||||
},
|
||||
default_models=[inference_model, embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
"run-with-safety.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"vector_io": [vector_io_provider_faiss],
|
||||
"files": [files_provider],
|
||||
"post_training": [posttraining_provider],
|
||||
"safety": [
|
||||
Provider(
|
||||
provider_id="llama-guard",
|
||||
provider_type="inline::llama-guard",
|
||||
config={},
|
||||
),
|
||||
Provider(
|
||||
provider_id="code-scanner",
|
||||
provider_type="inline::code-scanner",
|
||||
config={},
|
||||
),
|
||||
],
|
||||
},
|
||||
default_models=[
|
||||
inference_model,
|
||||
safety_model,
|
||||
embedding_model,
|
||||
],
|
||||
default_shields=[
|
||||
ShieldInput(
|
||||
shield_id="${env.SAFETY_MODEL}",
|
||||
provider_id="llama-guard",
|
||||
),
|
||||
ShieldInput(
|
||||
shield_id="CodeScanner",
|
||||
provider_id="code-scanner",
|
||||
),
|
||||
],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
"8321",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
"OLLAMA_URL": (
|
||||
"http://127.0.0.1:11434",
|
||||
"URL of the Ollama server",
|
||||
),
|
||||
"INFERENCE_MODEL": (
|
||||
"meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Inference model loaded into the Ollama server",
|
||||
),
|
||||
"SAFETY_MODEL": (
|
||||
"meta-llama/Llama-Guard-3-1B",
|
||||
"Safety model loaded into the Ollama server",
|
||||
),
|
||||
},
|
||||
)
|
158
llama_stack/templates/ollama/run-with-safety.yaml
Normal file
158
llama_stack/templates/ollama/run-with-safety.yaml
Normal file
|
@ -0,0 +1,158 @@
|
|||
version: 2
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: ollama
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.SAFETY_MODEL}
|
||||
provider_id: ollama
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: ollama
|
||||
provider_model_id: all-minilm:latest
|
||||
model_type: embedding
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL}
|
||||
provider_id: llama-guard
|
||||
- shield_id: CodeScanner
|
||||
provider_id: code-scanner
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
148
llama_stack/templates/ollama/run.yaml
Normal file
148
llama_stack/templates/ollama/run.yaml
Normal file
|
@ -0,0 +1,148 @@
|
|||
version: 2
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama}/files_metadata.db
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:=}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: ollama
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: ollama
|
||||
provider_model_id: all-minilm:latest
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
|
@ -70,11 +70,44 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
default_value = field.default_factory()
|
||||
# HACK ALERT:
|
||||
# If the default value contains a path that looks like it came from RUNTIME_BASE_DIR,
|
||||
# replace it with a generic ~/.llama/ path for documentation
|
||||
if isinstance(default_value, str) and "/.llama/" in default_value:
|
||||
if ".llama/" in default_value:
|
||||
# replace it with a generic XDG-compliant path for documentation
|
||||
if isinstance(default_value, str):
|
||||
# Handle legacy .llama/ paths
|
||||
if "/.llama/" in default_value:
|
||||
path_part = default_value.split(".llama/")[-1]
|
||||
default_value = f"~/.llama/{path_part}"
|
||||
# Use appropriate XDG directory based on path content
|
||||
if path_part.startswith("runtime/"):
|
||||
default_value = f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/{path_part}"
|
||||
else:
|
||||
default_value = f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/{path_part}"
|
||||
# Handle XDG state paths (runtime data)
|
||||
elif "/llama-stack/runtime/" in default_value:
|
||||
path_part = default_value.split("/llama-stack/runtime/")[-1]
|
||||
default_value = (
|
||||
f"${{env.XDG_STATE_HOME:-~/.local/state}}/llama-stack/runtime/{path_part}"
|
||||
)
|
||||
# Handle XDG data paths
|
||||
elif "/llama-stack/data/" in default_value or "/llama-stack/checkpoints/" in default_value:
|
||||
if "/llama-stack/data/" in default_value:
|
||||
path_part = default_value.split("/llama-stack/data/")[-1]
|
||||
default_value = (
|
||||
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/data/{path_part}"
|
||||
)
|
||||
else:
|
||||
path_part = default_value.split("/llama-stack/checkpoints/")[-1]
|
||||
default_value = (
|
||||
f"${{env.XDG_DATA_HOME:-~/.local/share}}/llama-stack/checkpoints/{path_part}"
|
||||
)
|
||||
# Handle XDG config paths
|
||||
elif "/llama-stack/" in default_value and (
|
||||
"/config/" in default_value or "/distributions/" in default_value
|
||||
):
|
||||
if "/config/" in default_value:
|
||||
path_part = default_value.split("/llama-stack/")[-1]
|
||||
default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}"
|
||||
else:
|
||||
path_part = default_value.split("/llama-stack/")[-1]
|
||||
default_value = f"${{env.XDG_CONFIG_HOME:-~/.config}}/llama-stack/{path_part}"
|
||||
except Exception:
|
||||
default_value = ""
|
||||
elif field.default is None:
|
||||
|
@ -201,7 +234,9 @@ def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
|
|||
if sample_config_func is not None:
|
||||
sig = inspect.signature(sample_config_func)
|
||||
if "__distro_dir__" in sig.parameters:
|
||||
sample_config = sample_config_func(__distro_dir__="~/.llama/dummy")
|
||||
sample_config = sample_config_func(
|
||||
__distro_dir__="${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/dummy"
|
||||
)
|
||||
else:
|
||||
sample_config = sample_config_func()
|
||||
|
||||
|
|
593
tests/integration/test_xdg_e2e.py
Normal file
593
tests/integration/test_xdg_e2e.py
Normal file
|
@ -0,0 +1,593 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestXDGEndToEnd(unittest.TestCase):
|
||||
"""End-to-end tests for XDG compliance workflows."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
self.env_vars = [
|
||||
"XDG_CONFIG_HOME",
|
||||
"XDG_DATA_HOME",
|
||||
"XDG_STATE_HOME",
|
||||
"XDG_CACHE_HOME",
|
||||
"LLAMA_STACK_CONFIG_DIR",
|
||||
"SQLITE_STORE_DIR",
|
||||
"FILES_STORAGE_DIR",
|
||||
]
|
||||
|
||||
for key in self.env_vars:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def clear_env_vars(self):
|
||||
"""Clear all relevant environment variables."""
|
||||
for key in self.env_vars:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def create_realistic_legacy_structure(self, base_dir: Path) -> Path:
|
||||
"""Create a realistic legacy ~/.llama directory structure."""
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
# Create distributions with realistic content
|
||||
distributions_dir = legacy_dir / "distributions"
|
||||
distributions_dir.mkdir()
|
||||
|
||||
# Ollama distribution
|
||||
ollama_dir = distributions_dir / "ollama"
|
||||
ollama_dir.mkdir()
|
||||
|
||||
ollama_run_yaml = ollama_dir / "ollama-run.yaml"
|
||||
ollama_run_yaml.write_text("""
|
||||
version: 2
|
||||
apis:
|
||||
- inference
|
||||
- safety
|
||||
- memory
|
||||
- vector_io
|
||||
- agents
|
||||
- files
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::ollama
|
||||
config:
|
||||
url: http://localhost:11434
|
||||
""")
|
||||
|
||||
ollama_build_yaml = ollama_dir / "build.yaml"
|
||||
ollama_build_yaml.write_text("""
|
||||
name: ollama
|
||||
description: Ollama inference provider
|
||||
docker_image: ollama:latest
|
||||
""")
|
||||
|
||||
# Create providers.d structure
|
||||
providers_dir = legacy_dir / "providers.d"
|
||||
providers_dir.mkdir()
|
||||
|
||||
remote_dir = providers_dir / "remote"
|
||||
remote_dir.mkdir()
|
||||
|
||||
inference_dir = remote_dir / "inference"
|
||||
inference_dir.mkdir()
|
||||
|
||||
custom_provider = inference_dir / "custom-inference.yaml"
|
||||
custom_provider.write_text("""
|
||||
provider_type: remote::custom
|
||||
config:
|
||||
url: http://localhost:8080
|
||||
api_key: test_key
|
||||
""")
|
||||
|
||||
# Create checkpoints with model files
|
||||
checkpoints_dir = legacy_dir / "checkpoints"
|
||||
checkpoints_dir.mkdir()
|
||||
|
||||
model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct"
|
||||
model_dir.mkdir(parents=True)
|
||||
|
||||
# Create fake model files
|
||||
(model_dir / "consolidated.00.pth").write_bytes(b"fake model weights" * 1000)
|
||||
(model_dir / "params.json").write_text('{"dim": 2048, "n_layers": 22}')
|
||||
(model_dir / "tokenizer.model").write_bytes(b"fake tokenizer" * 100)
|
||||
|
||||
# Create runtime with databases
|
||||
runtime_dir = legacy_dir / "runtime"
|
||||
runtime_dir.mkdir()
|
||||
|
||||
(runtime_dir / "trace_store.db").write_text("SQLite format 3\x00" + "fake database content")
|
||||
(runtime_dir / "agent_sessions.db").write_text("SQLite format 3\x00" + "fake agent sessions")
|
||||
|
||||
# Create config files
|
||||
(legacy_dir / "config.json").write_text('{"version": "0.2.13", "last_updated": "2024-01-01"}')
|
||||
|
||||
return legacy_dir
|
||||
|
||||
def verify_xdg_migration_complete(self, base_dir: Path, legacy_dir: Path):
|
||||
"""Verify that migration to XDG structure is complete and correct."""
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
data_dir = base_dir / ".local" / "share" / "llama-stack"
|
||||
state_dir = base_dir / ".local" / "state" / "llama-stack"
|
||||
|
||||
# Verify distributions moved to config
|
||||
self.assertTrue((config_dir / "distributions").exists())
|
||||
self.assertTrue((config_dir / "distributions" / "ollama").exists())
|
||||
self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
|
||||
|
||||
# Verify YAML content is preserved
|
||||
yaml_content = (config_dir / "distributions" / "ollama" / "ollama-run.yaml").read_text()
|
||||
self.assertIn("version: 2", yaml_content)
|
||||
self.assertIn("remote::ollama", yaml_content)
|
||||
|
||||
# Verify providers.d moved to config
|
||||
self.assertTrue((config_dir / "providers.d").exists())
|
||||
self.assertTrue((config_dir / "providers.d" / "remote" / "inference").exists())
|
||||
self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom-inference.yaml").exists())
|
||||
|
||||
# Verify checkpoints moved to data
|
||||
self.assertTrue((data_dir / "checkpoints").exists())
|
||||
self.assertTrue((data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct").exists())
|
||||
self.assertTrue(
|
||||
(data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists()
|
||||
)
|
||||
|
||||
# Verify model file content preserved
|
||||
model_file = data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth"
|
||||
self.assertGreater(model_file.stat().st_size, 1000) # Should be substantial size
|
||||
|
||||
# Verify runtime moved to state
|
||||
self.assertTrue((state_dir / "runtime").exists())
|
||||
self.assertTrue((state_dir / "runtime" / "trace_store.db").exists())
|
||||
self.assertTrue((state_dir / "runtime" / "agent_sessions.db").exists())
|
||||
|
||||
# Verify database files preserved
|
||||
db_file = state_dir / "runtime" / "trace_store.db"
|
||||
db_content = db_file.read_text()
|
||||
self.assertIn("SQLite format 3", db_content)
|
||||
|
||||
def test_fresh_installation_xdg_compliance(self):
|
||||
"""Test fresh installation uses XDG-compliant paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Set custom XDG paths
|
||||
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "custom_config")
|
||||
os.environ["XDG_DATA_HOME"] = str(base_dir / "custom_data")
|
||||
os.environ["XDG_STATE_HOME"] = str(base_dir / "custom_state")
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
# Mock that no legacy directory exists
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Fresh installation should use XDG paths
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
state_dir = get_llama_stack_state_dir()
|
||||
|
||||
self.assertEqual(config_dir, base_dir / "custom_config" / "llama-stack")
|
||||
self.assertEqual(data_dir, base_dir / "custom_data" / "llama-stack")
|
||||
self.assertEqual(state_dir, base_dir / "custom_state" / "llama-stack")
|
||||
|
||||
def test_complete_migration_workflow(self):
|
||||
"""Test complete migration workflow from legacy to XDG."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create realistic legacy structure
|
||||
legacy_dir = self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# Verify legacy structure exists
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
self.assertTrue((legacy_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
|
||||
|
||||
# Perform migration
|
||||
from llama_stack.cli.migrate_xdg import migrate_to_xdg
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify migration completed successfully
|
||||
self.verify_xdg_migration_complete(base_dir, legacy_dir)
|
||||
|
||||
# Verify legacy directory was removed
|
||||
self.assertFalse(legacy_dir.exists())
|
||||
|
||||
def test_migration_preserves_file_integrity(self):
|
||||
"""Test that migration preserves file integrity and permissions."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# Set specific permissions on files
|
||||
config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
|
||||
config_file.chmod(0o600)
|
||||
original_stat = config_file.stat()
|
||||
|
||||
# Perform migration
|
||||
from llama_stack.cli.migrate_xdg import migrate_to_xdg
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify file integrity
|
||||
migrated_config = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml"
|
||||
self.assertTrue(migrated_config.exists())
|
||||
|
||||
# Verify content is identical
|
||||
migrated_content = migrated_config.read_text()
|
||||
self.assertIn("version: 2", migrated_content)
|
||||
self.assertIn("remote::ollama", migrated_content)
|
||||
|
||||
# Verify permissions preserved
|
||||
migrated_stat = migrated_config.stat()
|
||||
self.assertEqual(original_stat.st_mode, migrated_stat.st_mode)
|
||||
|
||||
def test_mixed_legacy_and_xdg_environment(self):
|
||||
"""Test behavior in mixed legacy and XDG environment."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Set partial XDG environment
|
||||
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "xdg_config")
|
||||
# Leave XDG_DATA_HOME and XDG_STATE_HOME unset
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy directory
|
||||
legacy_dir = self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# Should use legacy directory since it exists
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
state_dir = get_llama_stack_state_dir()
|
||||
|
||||
self.assertEqual(config_dir, legacy_dir)
|
||||
self.assertEqual(data_dir, legacy_dir)
|
||||
self.assertEqual(state_dir, legacy_dir)
|
||||
|
||||
def test_template_rendering_with_xdg_paths(self):
|
||||
"""Test that templates render correctly with XDG paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Set XDG environment
|
||||
os.environ["XDG_STATE_HOME"] = str(base_dir / "state")
|
||||
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
|
||||
|
||||
# Mock shell environment variable expansion
|
||||
def mock_env_expand(template_string):
|
||||
"""Mock shell environment variable expansion."""
|
||||
result = template_string
|
||||
result = result.replace("${env.XDG_STATE_HOME:-~/.local/state}", str(base_dir / "state"))
|
||||
result = result.replace("${env.XDG_DATA_HOME:-~/.local/share}", str(base_dir / "data"))
|
||||
return result
|
||||
|
||||
# Test template patterns
|
||||
template_patterns = [
|
||||
"${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama",
|
||||
"${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama/files",
|
||||
]
|
||||
|
||||
expected_results = [
|
||||
str(base_dir / "state" / "llama-stack" / "distributions" / "ollama"),
|
||||
str(base_dir / "data" / "llama-stack" / "distributions" / "ollama" / "files"),
|
||||
]
|
||||
|
||||
for pattern, expected in zip(template_patterns, expected_results, strict=False):
|
||||
with self.subTest(pattern=pattern):
|
||||
expanded = mock_env_expand(pattern)
|
||||
self.assertEqual(expanded, expected)
|
||||
|
||||
def test_cli_integration_with_xdg_paths(self):
|
||||
"""Test CLI integration works correctly with XDG paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# Test CLI migrate command
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.migrate_xdg import MigrateXDG
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
migrate_cmd = MigrateXDG.create(subparsers)
|
||||
|
||||
# Test dry-run
|
||||
args = parser.parse_args(["migrate-xdg", "--dry-run"])
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
# Should print dry-run information
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Dry run mode" in call for call in print_calls))
|
||||
|
||||
# Legacy directory should still exist after dry-run
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
|
||||
def test_config_dirs_integration_after_migration(self):
|
||||
"""Test that config_dirs works correctly after migration."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create and migrate legacy structure
|
||||
self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
from llama_stack.cli.migrate_xdg import migrate_to_xdg
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Clear module cache to ensure fresh import
|
||||
import sys
|
||||
|
||||
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
|
||||
del sys.modules["llama_stack.distribution.utils.config_dirs"]
|
||||
|
||||
# Import config_dirs after migration
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# Should use XDG paths
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, base_dir / ".config" / "llama-stack")
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, base_dir / ".local" / "share" / "llama-stack" / "checkpoints")
|
||||
self.assertEqual(RUNTIME_BASE_DIR, base_dir / ".local" / "state" / "llama-stack" / "runtime")
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, base_dir / ".config" / "llama-stack" / "distributions")
|
||||
|
||||
def test_real_file_operations_with_xdg_paths(self):
|
||||
"""Test real file operations work correctly with XDG paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Set XDG environment
|
||||
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config")
|
||||
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
|
||||
os.environ["XDG_STATE_HOME"] = str(base_dir / "state")
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Get XDG paths
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
state_dir = get_llama_stack_state_dir()
|
||||
|
||||
# Create directories
|
||||
config_dir.mkdir(parents=True)
|
||||
data_dir.mkdir(parents=True)
|
||||
state_dir.mkdir(parents=True)
|
||||
|
||||
# Test writing configuration files
|
||||
config_file = config_dir / "test_config.yaml"
|
||||
config_data = {"version": "2", "test": True}
|
||||
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
# Test reading configuration files
|
||||
with open(config_file) as f:
|
||||
loaded_config = yaml.safe_load(f)
|
||||
|
||||
self.assertEqual(loaded_config, config_data)
|
||||
|
||||
# Test creating nested directory structure
|
||||
model_dir = data_dir / "checkpoints" / "meta-llama" / "test-model"
|
||||
model_dir.mkdir(parents=True)
|
||||
|
||||
# Test writing large files
|
||||
model_file = model_dir / "model.bin"
|
||||
test_data = b"test model data" * 1000
|
||||
model_file.write_bytes(test_data)
|
||||
|
||||
# Verify file integrity
|
||||
read_data = model_file.read_bytes()
|
||||
self.assertEqual(read_data, test_data)
|
||||
|
||||
# Test state files
|
||||
state_file = state_dir / "runtime" / "session.db"
|
||||
state_file.parent.mkdir(parents=True)
|
||||
state_file.write_text("SQLite format 3\x00test database")
|
||||
|
||||
# Verify state file
|
||||
state_content = state_file.read_text()
|
||||
self.assertIn("SQLite format 3", state_content)
|
||||
|
||||
def test_backwards_compatibility_scenario(self):
|
||||
"""Test complete backwards compatibility scenario."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Scenario: User has existing legacy installation
|
||||
legacy_dir = self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# User sets legacy environment variable
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Should continue using legacy paths
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
state_dir = get_llama_stack_state_dir()
|
||||
|
||||
self.assertEqual(config_dir, legacy_dir)
|
||||
self.assertEqual(data_dir, legacy_dir)
|
||||
self.assertEqual(state_dir, legacy_dir)
|
||||
|
||||
# Should be able to access existing files
|
||||
yaml_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
|
||||
self.assertTrue(yaml_file.exists())
|
||||
|
||||
# Should be able to parse existing configuration
|
||||
with open(yaml_file) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
self.assertEqual(config["version"], 2)
|
||||
|
||||
def test_error_recovery_scenarios(self):
|
||||
"""Test error recovery in various scenarios."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Scenario 1: Partial migration failure
|
||||
self.create_realistic_legacy_structure(base_dir)
|
||||
|
||||
# Create conflicting file in target location
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
config_dir.mkdir(parents=True)
|
||||
|
||||
conflicting_file = config_dir / "distributions"
|
||||
conflicting_file.touch() # Create file instead of directory
|
||||
|
||||
from llama_stack.cli.migrate_xdg import migrate_to_xdg
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
|
||||
# Should handle conflicts gracefully
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
conflict_mentioned = any(
|
||||
"Warning" in call or "conflict" in call.lower() for call in print_calls
|
||||
)
|
||||
|
||||
# Migration should complete with warnings
|
||||
self.assertTrue(result or conflict_mentioned)
|
||||
|
||||
def test_cross_platform_compatibility(self):
|
||||
"""Test cross-platform compatibility of XDG implementation."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Test with different path separators and formats
|
||||
if os.name == "nt": # Windows
|
||||
# Test Windows-style paths
|
||||
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config").replace("/", "\\")
|
||||
os.environ["XDG_DATA_HOME"] = str(base_dir / "data").replace("/", "\\")
|
||||
else: # Unix-like
|
||||
# Test Unix-style paths
|
||||
os.environ["XDG_CONFIG_HOME"] = str(base_dir / "config")
|
||||
os.environ["XDG_DATA_HOME"] = str(base_dir / "data")
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Should work regardless of platform
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
|
||||
# Paths should be valid for the current platform
|
||||
self.assertTrue(config_dir.is_absolute())
|
||||
self.assertTrue(data_dir.is_absolute())
|
||||
|
||||
# Should be able to create directories
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.assertTrue(config_dir.exists())
|
||||
self.assertTrue(data_dir.exists())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
516
tests/integration/test_xdg_migration.py
Normal file
516
tests/integration/test_xdg_migration.py
Normal file
|
@ -0,0 +1,516 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.cli.migrate_xdg import migrate_to_xdg
|
||||
|
||||
|
||||
class TestXDGMigrationIntegration(unittest.TestCase):
|
||||
"""Integration tests for XDG migration functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
# Store original environment variables
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
# Clear environment variables
|
||||
for key in self.original_env:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
# Restore original environment variables
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def create_legacy_structure(self, base_dir: Path) -> Path:
|
||||
"""Create a realistic legacy ~/.llama directory structure."""
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
# Create distributions
|
||||
distributions_dir = legacy_dir / "distributions"
|
||||
distributions_dir.mkdir()
|
||||
|
||||
# Create sample distribution
|
||||
ollama_dir = distributions_dir / "ollama"
|
||||
ollama_dir.mkdir()
|
||||
(ollama_dir / "ollama-run.yaml").write_text("version: 2\napis: []\n")
|
||||
(ollama_dir / "build.yaml").write_text("name: ollama\n")
|
||||
|
||||
# Create providers.d
|
||||
providers_dir = legacy_dir / "providers.d"
|
||||
providers_dir.mkdir()
|
||||
(providers_dir / "remote").mkdir()
|
||||
(providers_dir / "remote" / "inference").mkdir()
|
||||
(providers_dir / "remote" / "inference" / "custom.yaml").write_text("provider_type: remote::custom\n")
|
||||
|
||||
# Create checkpoints
|
||||
checkpoints_dir = legacy_dir / "checkpoints"
|
||||
checkpoints_dir.mkdir()
|
||||
model_dir = checkpoints_dir / "meta-llama" / "Llama-3.2-1B-Instruct"
|
||||
model_dir.mkdir(parents=True)
|
||||
(model_dir / "consolidated.00.pth").write_text("fake model weights")
|
||||
(model_dir / "params.json").write_text('{"dim": 2048}')
|
||||
|
||||
# Create runtime
|
||||
runtime_dir = legacy_dir / "runtime"
|
||||
runtime_dir.mkdir()
|
||||
(runtime_dir / "trace_store.db").write_text("fake sqlite database")
|
||||
|
||||
# Create some fake files in various subdirectories
|
||||
(legacy_dir / "config.json").write_text('{"version": "0.2.13"}')
|
||||
|
||||
return legacy_dir
|
||||
|
||||
def verify_xdg_structure(self, base_dir: Path, legacy_dir: Path):
|
||||
"""Verify that the XDG structure was created correctly."""
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
data_dir = base_dir / ".local" / "share" / "llama-stack"
|
||||
state_dir = base_dir / ".local" / "state" / "llama-stack"
|
||||
|
||||
# Verify distributions moved to config
|
||||
self.assertTrue((config_dir / "distributions").exists())
|
||||
self.assertTrue((config_dir / "distributions" / "ollama" / "ollama-run.yaml").exists())
|
||||
|
||||
# Verify providers.d moved to config
|
||||
self.assertTrue((config_dir / "providers.d").exists())
|
||||
self.assertTrue((config_dir / "providers.d" / "remote" / "inference" / "custom.yaml").exists())
|
||||
|
||||
# Verify checkpoints moved to data
|
||||
self.assertTrue((data_dir / "checkpoints").exists())
|
||||
self.assertTrue(
|
||||
(data_dir / "checkpoints" / "meta-llama" / "Llama-3.2-1B-Instruct" / "consolidated.00.pth").exists()
|
||||
)
|
||||
|
||||
# Verify runtime moved to state
|
||||
self.assertTrue((state_dir / "runtime").exists())
|
||||
self.assertTrue((state_dir / "runtime" / "trace_store.db").exists())
|
||||
|
||||
def test_full_migration_workflow(self):
|
||||
"""Test complete migration workflow from legacy to XDG."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
# Set up fake home directory
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Verify legacy structure exists
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
self.assertTrue((legacy_dir / "distributions").exists())
|
||||
self.assertTrue((legacy_dir / "checkpoints").exists())
|
||||
|
||||
# Perform migration with user confirmation
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify XDG structure was created
|
||||
self.verify_xdg_structure(base_dir, legacy_dir)
|
||||
|
||||
# Verify legacy directory was removed
|
||||
self.assertFalse(legacy_dir.exists())
|
||||
|
||||
def test_migration_dry_run(self):
|
||||
"""Test dry run migration (no actual file movement)."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Perform dry run migration
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=True)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check that dry run message was printed
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Dry run mode" in call for call in print_calls))
|
||||
|
||||
# Verify nothing was actually moved
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
self.assertTrue((legacy_dir / "distributions").exists())
|
||||
self.assertFalse((base_dir / ".config" / "llama-stack").exists())
|
||||
|
||||
def test_migration_user_cancellation(self):
|
||||
"""Test migration when user cancels."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# User cancels migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.return_value = "n"
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Verify nothing was moved
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
self.assertTrue((legacy_dir / "distributions").exists())
|
||||
self.assertFalse((base_dir / ".config" / "llama-stack").exists())
|
||||
|
||||
def test_migration_with_existing_xdg_directories(self):
|
||||
"""Test migration when XDG directories already exist."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create existing XDG structure with conflicting files
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
config_dir.mkdir(parents=True)
|
||||
(config_dir / "distributions").mkdir()
|
||||
(config_dir / "distributions" / "existing.yaml").write_text("existing config")
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check that warning was printed
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Warning: Target already exists" in call for call in print_calls))
|
||||
|
||||
# Verify existing file wasn't overwritten
|
||||
self.assertTrue((config_dir / "distributions" / "existing.yaml").exists())
|
||||
|
||||
# Legacy distributions should still exist due to conflict
|
||||
self.assertTrue((legacy_dir / "distributions").exists())
|
||||
|
||||
def test_migration_partial_success(self):
|
||||
"""Test migration when some items succeed and others fail."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create readonly target directory to simulate permission error
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
config_dir.mkdir(parents=True)
|
||||
distributions_target = config_dir / "distributions"
|
||||
distributions_target.mkdir()
|
||||
distributions_target.chmod(0o444) # Read-only
|
||||
|
||||
try:
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
|
||||
|
||||
migrate_to_xdg(dry_run=False)
|
||||
# Should return True even with partial success
|
||||
|
||||
# Some items should have been migrated successfully
|
||||
self.assertTrue((base_dir / ".local" / "share" / "llama-stack" / "checkpoints").exists())
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
distributions_target.chmod(0o755)
|
||||
|
||||
def test_migration_empty_legacy_directory(self):
|
||||
"""Test migration when legacy directory exists but is empty."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create empty legacy directory
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_migration_preserves_file_permissions(self):
|
||||
"""Test that migration preserves file permissions."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure with specific permissions
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Set specific permissions on a file
|
||||
config_file = legacy_dir / "distributions" / "ollama" / "ollama-run.yaml"
|
||||
config_file.chmod(0o600)
|
||||
original_stat = config_file.stat()
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify permissions were preserved
|
||||
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "ollama-run.yaml"
|
||||
self.assertTrue(migrated_file.exists())
|
||||
migrated_stat = migrated_file.stat()
|
||||
self.assertEqual(original_stat.st_mode, migrated_stat.st_mode)
|
||||
|
||||
def test_migration_preserves_directory_structure(self):
|
||||
"""Test that migration preserves complex directory structures."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create complex legacy structure
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
# Create nested structure
|
||||
complex_path = legacy_dir / "checkpoints" / "org" / "model" / "variant" / "files"
|
||||
complex_path.mkdir(parents=True)
|
||||
(complex_path / "model.bin").write_text("model data")
|
||||
(complex_path / "config.json").write_text("config data")
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify structure was preserved
|
||||
migrated_path = (
|
||||
base_dir
|
||||
/ ".local"
|
||||
/ "share"
|
||||
/ "llama-stack"
|
||||
/ "checkpoints"
|
||||
/ "org"
|
||||
/ "model"
|
||||
/ "variant"
|
||||
/ "files"
|
||||
)
|
||||
self.assertTrue(migrated_path.exists())
|
||||
self.assertTrue((migrated_path / "model.bin").exists())
|
||||
self.assertTrue((migrated_path / "config.json").exists())
|
||||
|
||||
def test_migration_with_symlinks(self):
|
||||
"""Test migration with symbolic links in legacy directory."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create a symlink
|
||||
actual_file = legacy_dir / "actual_config.yaml"
|
||||
actual_file.write_text("actual config content")
|
||||
|
||||
symlink_file = legacy_dir / "distributions" / "symlinked_config.yaml"
|
||||
symlink_file.symlink_to(actual_file)
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify symlink was preserved
|
||||
migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked_config.yaml"
|
||||
self.assertTrue(migrated_symlink.exists())
|
||||
self.assertTrue(migrated_symlink.is_symlink())
|
||||
|
||||
def test_migration_large_files(self):
|
||||
"""Test migration with large files (simulated)."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create a larger file (1MB)
|
||||
large_file = legacy_dir / "checkpoints" / "large_model.bin"
|
||||
large_file.write_bytes(b"0" * (1024 * 1024))
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify large file was moved correctly
|
||||
migrated_file = base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "large_model.bin"
|
||||
self.assertTrue(migrated_file.exists())
|
||||
self.assertEqual(migrated_file.stat().st_size, 1024 * 1024)
|
||||
|
||||
def test_migration_with_unicode_filenames(self):
|
||||
"""Test migration with unicode filenames."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy structure with unicode filenames
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
unicode_dir = legacy_dir / "distributions" / "配置"
|
||||
unicode_dir.mkdir(parents=True)
|
||||
unicode_file = unicode_dir / "模型.yaml"
|
||||
unicode_file.write_text("unicode content")
|
||||
|
||||
# Perform migration
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"]
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Verify unicode files were migrated
|
||||
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "配置" / "模型.yaml"
|
||||
self.assertTrue(migrated_file.exists())
|
||||
self.assertEqual(migrated_file.read_text(), "unicode content")
|
||||
|
||||
|
||||
class TestXDGMigrationCLI(unittest.TestCase):
|
||||
"""Test the CLI interface for XDG migration."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def test_cli_migrate_command_exists(self):
|
||||
"""Test that the migrate-xdg CLI command is properly registered."""
|
||||
from llama_stack.cli.llama import LlamaCLIParser
|
||||
|
||||
parser = LlamaCLIParser()
|
||||
|
||||
# Parse help to see if migrate-xdg is listed
|
||||
with patch("sys.argv", ["llama", "--help"]):
|
||||
with patch("sys.exit"):
|
||||
with patch("builtins.print") as mock_print:
|
||||
try:
|
||||
parser.parse_args()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
# Check if migrate-xdg appears in help output
|
||||
help_output = "\n".join([call[0][0] for call in mock_print.call_args_list])
|
||||
self.assertIn("migrate-xdg", help_output)
|
||||
|
||||
def test_cli_migrate_dry_run(self):
|
||||
"""Test CLI migrate command with dry-run flag."""
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.migrate_xdg import MigrateXDG
|
||||
|
||||
# Create parser and add migrate command
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
MigrateXDG.create(subparsers)
|
||||
|
||||
# Test dry-run flag
|
||||
args = parser.parse_args(["migrate-xdg", "--dry-run"])
|
||||
self.assertTrue(args.dry_run)
|
||||
|
||||
# Test without dry-run flag
|
||||
args = parser.parse_args(["migrate-xdg"])
|
||||
self.assertFalse(args.dry_run)
|
||||
|
||||
def test_cli_migrate_execution(self):
|
||||
"""Test CLI migrate command execution."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create legacy directory
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch()
|
||||
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.migrate_xdg import MigrateXDG
|
||||
|
||||
# Create parser and command
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
migrate_cmd = MigrateXDG.create(subparsers)
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args(["migrate-xdg", "--dry-run"])
|
||||
|
||||
# Execute command
|
||||
with patch("builtins.print") as mock_print:
|
||||
migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
|
||||
# Verify output was printed
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
164
tests/run_xdg_tests.py
Executable file
164
tests/run_xdg_tests.py
Executable file
|
@ -0,0 +1,164 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test runner for XDG Base Directory Specification compliance tests.
|
||||
|
||||
This script runs all XDG-related tests and provides a comprehensive report
|
||||
of the test results.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_test_suite():
|
||||
"""Run the complete XDG test suite."""
|
||||
|
||||
# Set up test environment
|
||||
test_dir = Path(__file__).parent
|
||||
project_root = test_dir.parent
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Test modules to run
|
||||
test_modules = [
|
||||
"tests.unit.test_xdg_compliance",
|
||||
"tests.unit.test_config_dirs",
|
||||
"tests.unit.cli.test_migrate_xdg",
|
||||
"tests.unit.test_template_xdg_paths",
|
||||
"tests.unit.test_xdg_edge_cases",
|
||||
"tests.integration.test_xdg_migration",
|
||||
"tests.integration.test_xdg_e2e",
|
||||
]
|
||||
|
||||
# Discover and run tests
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
print("🔍 Discovering XDG compliance tests...")
|
||||
|
||||
for module_name in test_modules:
|
||||
try:
|
||||
# Try to load the test module
|
||||
module_suite = loader.loadTestsFromName(module_name)
|
||||
suite.addTest(module_suite)
|
||||
print(f" ✅ Loaded {module_name}")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to load {module_name}: {e}")
|
||||
|
||||
# Run the tests
|
||||
print("\n🧪 Running XDG compliance tests...")
|
||||
print("=" * 60)
|
||||
|
||||
runner = unittest.TextTestRunner(verbosity=2, stream=sys.stdout, descriptions=True, buffer=True)
|
||||
|
||||
result = runner.run(suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 Test Summary:")
|
||||
print(f" Tests run: {result.testsRun}")
|
||||
print(f" Failures: {len(result.failures)}")
|
||||
print(f" Errors: {len(result.errors)}")
|
||||
print(f" Skipped: {len(result.skipped)}")
|
||||
|
||||
if result.failures:
|
||||
print("\n❌ Failures:")
|
||||
for test, traceback in result.failures:
|
||||
print(f" - {test}: {traceback.split('AssertionError:')[-1].strip()}")
|
||||
|
||||
if result.errors:
|
||||
print("\n💥 Errors:")
|
||||
for test, traceback in result.errors:
|
||||
print(f" - {test}: {traceback.split('Exception:')[-1].strip()}")
|
||||
|
||||
if result.skipped:
|
||||
print("\n⏭️ Skipped:")
|
||||
for test, reason in result.skipped:
|
||||
print(f" - {test}: {reason}")
|
||||
|
||||
# Overall result
|
||||
if result.wasSuccessful():
|
||||
print("\n🎉 All XDG compliance tests passed!")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some XDG compliance tests failed.")
|
||||
return 1
|
||||
|
||||
|
||||
def run_quick_tests():
|
||||
"""Run a quick subset of critical XDG tests."""
|
||||
|
||||
print("🚀 Running quick XDG compliance tests...")
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Quick test: Basic XDG functionality
|
||||
try:
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_xdg_compliant_path,
|
||||
get_xdg_config_home,
|
||||
)
|
||||
|
||||
print(" ✅ XDG utilities import successfully")
|
||||
|
||||
# Test basic functionality
|
||||
config_home = get_xdg_config_home()
|
||||
llama_config = get_llama_stack_config_dir()
|
||||
compliant_path = get_xdg_compliant_path("config", "test")
|
||||
|
||||
print(f" ✅ XDG config home: {config_home}")
|
||||
print(f" ✅ Llama Stack config: {llama_config}")
|
||||
print(f" ✅ Compliant path: {compliant_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ XDG utilities failed: {e}")
|
||||
return 1
|
||||
|
||||
# Quick test: Config dirs integration
|
||||
try:
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
)
|
||||
|
||||
print(f" ✅ Config dirs integration: {LLAMA_STACK_CONFIG_DIR}")
|
||||
print(f" ✅ Checkpoint directory: {DEFAULT_CHECKPOINT_DIR}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Config dirs integration failed: {e}")
|
||||
return 1
|
||||
|
||||
# Quick test: CLI migrate command
|
||||
try:
|
||||
print(" ✅ CLI migrate command available")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ CLI migrate command failed: {e}")
|
||||
return 1
|
||||
|
||||
print("\n🎉 Quick XDG compliance tests passed!")
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
"""Main test runner entry point."""
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "--quick":
|
||||
return run_quick_tests()
|
||||
else:
|
||||
return run_test_suite()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
489
tests/unit/cli/test_migrate_xdg.py
Normal file
489
tests/unit/cli/test_migrate_xdg.py
Normal file
|
@ -0,0 +1,489 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.cli.migrate_xdg import MigrateXDG, migrate_to_xdg
|
||||
|
||||
|
||||
class TestMigrateXDGCLI(unittest.TestCase):
|
||||
"""Test the MigrateXDG CLI command."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def create_parser_with_migrate_cmd(self):
|
||||
"""Create parser with migrate-xdg command."""
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
migrate_cmd = MigrateXDG.create(subparsers)
|
||||
return parser, migrate_cmd
|
||||
|
||||
def test_migrate_xdg_command_creation(self):
|
||||
"""Test that MigrateXDG command can be created."""
|
||||
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
|
||||
|
||||
self.assertIsInstance(migrate_cmd, MigrateXDG)
|
||||
self.assertEqual(migrate_cmd.parser.prog, "llama migrate-xdg")
|
||||
self.assertEqual(migrate_cmd.parser.description, "Migrate from legacy ~/.llama to XDG-compliant directories")
|
||||
|
||||
def test_migrate_xdg_argument_parsing(self):
|
||||
"""Test argument parsing for migrate-xdg command."""
|
||||
parser, _ = self.create_parser_with_migrate_cmd()
|
||||
|
||||
# Test with dry-run flag
|
||||
args = parser.parse_args(["migrate-xdg", "--dry-run"])
|
||||
self.assertEqual(args.command, "migrate-xdg")
|
||||
self.assertTrue(args.dry_run)
|
||||
|
||||
# Test without dry-run flag
|
||||
args = parser.parse_args(["migrate-xdg"])
|
||||
self.assertEqual(args.command, "migrate-xdg")
|
||||
self.assertFalse(args.dry_run)
|
||||
|
||||
def test_migrate_xdg_help_text(self):
|
||||
"""Test help text for migrate-xdg command."""
|
||||
parser, _ = self.create_parser_with_migrate_cmd()
|
||||
|
||||
# Capture help output
|
||||
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
|
||||
with patch("sys.exit"):
|
||||
try:
|
||||
parser.parse_args(["migrate-xdg", "--help"])
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
help_text = mock_stdout.getvalue()
|
||||
self.assertIn("migrate-xdg", help_text)
|
||||
self.assertIn("XDG-compliant directories", help_text)
|
||||
self.assertIn("--dry-run", help_text)
|
||||
|
||||
def test_migrate_xdg_command_execution_no_legacy(self):
|
||||
"""Test command execution when no legacy directory exists."""
|
||||
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = Path(temp_dir)
|
||||
|
||||
args = parser.parse_args(["migrate-xdg"])
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
|
||||
# Should succeed when no migration needed
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
# Should print appropriate message
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("No legacy directory found" in call for call in print_calls))
|
||||
|
||||
def test_migrate_xdg_command_execution_with_legacy(self):
|
||||
"""Test command execution when legacy directory exists."""
|
||||
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
args = parser.parse_args(["migrate-xdg"])
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
|
||||
# Should succeed
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
# Should print migration information
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
|
||||
|
||||
def test_migrate_xdg_command_execution_dry_run(self):
|
||||
"""Test command execution with dry-run flag."""
|
||||
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
args = parser.parse_args(["migrate-xdg", "--dry-run"])
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
|
||||
# Should succeed
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
# Should print dry-run information
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Dry run mode" in call for call in print_calls))
|
||||
|
||||
def test_migrate_xdg_command_execution_error_handling(self):
|
||||
"""Test command execution with error handling."""
|
||||
parser, migrate_cmd = self.create_parser_with_migrate_cmd()
|
||||
|
||||
args = parser.parse_args(["migrate-xdg"])
|
||||
|
||||
# Mock migrate_to_xdg to raise an exception
|
||||
with patch("llama_stack.cli.migrate_xdg.migrate_to_xdg") as mock_migrate:
|
||||
mock_migrate.side_effect = Exception("Test error")
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_cmd._run_migrate_xdg_cmd(args)
|
||||
|
||||
# Should return error code
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
# Should print error message
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Error during migration" in call for call in print_calls))
|
||||
|
||||
def test_migrate_xdg_command_integration(self):
|
||||
"""Test full integration of migrate-xdg command."""
|
||||
from llama_stack.cli.llama import LlamaCLIParser
|
||||
|
||||
# Create main parser
|
||||
main_parser = LlamaCLIParser()
|
||||
|
||||
# Test that migrate-xdg is in the subcommands
|
||||
with patch("sys.argv", ["llama", "migrate-xdg", "--help"]):
|
||||
with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
|
||||
with patch("sys.exit"):
|
||||
try:
|
||||
main_parser.parse_args()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
help_text = mock_stdout.getvalue()
|
||||
self.assertIn("migrate-xdg", help_text)
|
||||
|
||||
|
||||
class TestMigrateXDGFunction(unittest.TestCase):
|
||||
"""Test the migrate_to_xdg function directly."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def create_legacy_structure(self, base_dir: Path) -> Path:
|
||||
"""Create a test legacy directory structure."""
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
|
||||
# Create distributions
|
||||
(legacy_dir / "distributions").mkdir()
|
||||
(legacy_dir / "distributions" / "ollama").mkdir()
|
||||
(legacy_dir / "distributions" / "ollama" / "run.yaml").write_text("version: 2\n")
|
||||
|
||||
# Create checkpoints
|
||||
(legacy_dir / "checkpoints").mkdir()
|
||||
(legacy_dir / "checkpoints" / "model.bin").write_text("fake model")
|
||||
|
||||
# Create providers.d
|
||||
(legacy_dir / "providers.d").mkdir()
|
||||
(legacy_dir / "providers.d" / "provider.yaml").write_text("provider: test\n")
|
||||
|
||||
# Create runtime
|
||||
(legacy_dir / "runtime").mkdir()
|
||||
(legacy_dir / "runtime" / "trace.db").write_text("fake database")
|
||||
|
||||
return legacy_dir
|
||||
|
||||
def test_migrate_to_xdg_no_legacy_directory(self):
|
||||
"""Test migrate_to_xdg when no legacy directory exists."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = Path(temp_dir)
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_migrate_to_xdg_dry_run(self):
|
||||
"""Test migrate_to_xdg with dry_run=True."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=True)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Should print dry run information
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Dry run mode" in call for call in print_calls))
|
||||
|
||||
# Legacy directory should still exist
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
|
||||
def test_migrate_to_xdg_user_confirms(self):
|
||||
"""Test migrate_to_xdg when user confirms migration."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Legacy directory should be removed
|
||||
self.assertFalse(legacy_dir.exists())
|
||||
|
||||
# XDG directories should be created
|
||||
self.assertTrue((base_dir / ".config" / "llama-stack").exists())
|
||||
self.assertTrue((base_dir / ".local" / "share" / "llama-stack").exists())
|
||||
|
||||
def test_migrate_to_xdg_user_cancels(self):
|
||||
"""Test migrate_to_xdg when user cancels migration."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.return_value = "n" # Cancel migration
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Legacy directory should still exist
|
||||
self.assertTrue(legacy_dir.exists())
|
||||
|
||||
def test_migrate_to_xdg_partial_migration(self):
|
||||
"""Test migrate_to_xdg with partial migration (some files fail)."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
self.assertFalse(legacy_dir.exists())
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create target directory with conflicting file
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
config_dir.mkdir(parents=True)
|
||||
(config_dir / "distributions").mkdir()
|
||||
(config_dir / "distributions" / "existing.yaml").write_text("existing")
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Should print warning about conflicts
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Warning: Target already exists" in call for call in print_calls))
|
||||
|
||||
def test_migrate_to_xdg_permission_error(self):
|
||||
"""Test migrate_to_xdg with permission errors."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
self.assertFalse(legacy_dir.exists())
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Create readonly target directory
|
||||
config_dir = base_dir / ".config" / "llama-stack"
|
||||
config_dir.mkdir(parents=True)
|
||||
config_dir.chmod(0o444) # Read-only
|
||||
|
||||
try:
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "n"] # Confirm migration, don't cleanup
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Should handle permission errors gracefully
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
# Should contain some error or warning message
|
||||
self.assertTrue(len(print_calls) > 0)
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
config_dir.chmod(0o755)
|
||||
|
||||
def test_migrate_to_xdg_empty_legacy_directory(self):
|
||||
"""Test migrate_to_xdg with empty legacy directory."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = base_dir / ".llama"
|
||||
legacy_dir.mkdir() # Empty directory
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_migrate_to_xdg_preserves_file_content(self):
|
||||
"""Test that migrate_to_xdg preserves file content correctly."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Add specific content to test
|
||||
test_content = "test configuration content"
|
||||
(legacy_dir / "distributions" / "ollama" / "run.yaml").write_text(test_content)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check content was preserved
|
||||
migrated_file = base_dir / ".config" / "llama-stack" / "distributions" / "ollama" / "run.yaml"
|
||||
self.assertTrue(migrated_file.exists())
|
||||
self.assertEqual(migrated_file.read_text(), test_content)
|
||||
|
||||
def test_migrate_to_xdg_with_symlinks(self):
|
||||
"""Test migrate_to_xdg with symbolic links."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create symlink
|
||||
actual_file = legacy_dir / "actual_config.yaml"
|
||||
actual_file.write_text("actual config")
|
||||
|
||||
symlink_file = legacy_dir / "distributions" / "symlinked.yaml"
|
||||
symlink_file.symlink_to(actual_file)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check symlink was preserved
|
||||
migrated_symlink = base_dir / ".config" / "llama-stack" / "distributions" / "symlinked.yaml"
|
||||
self.assertTrue(migrated_symlink.exists())
|
||||
self.assertTrue(migrated_symlink.is_symlink())
|
||||
|
||||
def test_migrate_to_xdg_nested_directory_structure(self):
|
||||
"""Test migrate_to_xdg with nested directory structures."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
# Create nested structure
|
||||
nested_dir = legacy_dir / "checkpoints" / "org" / "model" / "variant"
|
||||
nested_dir.mkdir(parents=True)
|
||||
(nested_dir / "model.bin").write_text("nested model")
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = ["y", "y"] # Confirm migration and cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check nested structure was preserved
|
||||
migrated_nested = (
|
||||
base_dir / ".local" / "share" / "llama-stack" / "checkpoints" / "org" / "model" / "variant"
|
||||
)
|
||||
self.assertTrue(migrated_nested.exists())
|
||||
self.assertTrue((migrated_nested / "model.bin").exists())
|
||||
|
||||
def test_migrate_to_xdg_user_input_variations(self):
|
||||
"""Test migrate_to_xdg with various user input variations."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir)
|
||||
legacy_dir = self.create_legacy_structure(base_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = base_dir
|
||||
|
||||
# Test various forms of "yes"
|
||||
for yes_input in ["y", "Y", "yes", "Yes", "YES"]:
|
||||
# Recreate legacy directory for each test
|
||||
if legacy_dir.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(legacy_dir)
|
||||
self.create_legacy_structure(base_dir)
|
||||
|
||||
with patch("builtins.input") as mock_input:
|
||||
mock_input.side_effect = [yes_input, "n"] # Confirm migration, don't cleanup
|
||||
|
||||
result = migrate_to_xdg(dry_run=False)
|
||||
self.assertTrue(result, f"Failed with input: {yes_input}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
418
tests/unit/test_config_dirs.py
Normal file
418
tests/unit/test_config_dirs.py
Normal file
|
@ -0,0 +1,418 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
# Import after we set up environment to avoid module-level imports affecting tests
|
||||
class TestConfigDirs(unittest.TestCase):
|
||||
"""Test the config_dirs module with XDG compliance and backwards compatibility."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
# Store original environment variables
|
||||
self.original_env = {}
|
||||
self.env_vars = [
|
||||
"XDG_CONFIG_HOME",
|
||||
"XDG_DATA_HOME",
|
||||
"XDG_STATE_HOME",
|
||||
"XDG_CACHE_HOME",
|
||||
"LLAMA_STACK_CONFIG_DIR",
|
||||
"SQLITE_STORE_DIR",
|
||||
"FILES_STORAGE_DIR",
|
||||
]
|
||||
|
||||
for key in self.env_vars:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
# Restore original environment variables
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
# Clear module cache to ensure fresh imports
|
||||
import sys
|
||||
|
||||
modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
def clear_env_vars(self):
|
||||
"""Clear all relevant environment variables."""
|
||||
for key in self.env_vars:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_config_dirs_xdg_defaults(self):
|
||||
"""Test config_dirs with XDG default paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Mock that no legacy directory exists
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/home/testuser")
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Import after setting up mocks
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# Verify XDG-compliant paths
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack"))
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints"))
|
||||
self.assertEqual(RUNTIME_BASE_DIR, Path("/home/testuser/.local/state/llama-stack/runtime"))
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/home/testuser/.config/llama-stack/providers.d"))
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, Path("/home/testuser/.config/llama-stack/distributions"))
|
||||
|
||||
def test_config_dirs_custom_xdg_paths(self):
|
||||
"""Test config_dirs with custom XDG paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Set custom XDG paths
|
||||
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
|
||||
os.environ["XDG_DATA_HOME"] = "/custom/data"
|
||||
os.environ["XDG_STATE_HOME"] = "/custom/state"
|
||||
|
||||
# Mock that no legacy directory exists
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# Verify custom XDG paths are used
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/custom/config/llama-stack"))
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/custom/data/llama-stack/checkpoints"))
|
||||
self.assertEqual(RUNTIME_BASE_DIR, Path("/custom/state/llama-stack/runtime"))
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR, Path("/custom/config/llama-stack/providers.d"))
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, Path("/custom/config/llama-stack/distributions"))
|
||||
|
||||
def test_config_dirs_legacy_environment_variable(self):
|
||||
"""Test config_dirs with legacy LLAMA_STACK_CONFIG_DIR."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Set legacy environment variable
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/llama"
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# All paths should use the legacy base
|
||||
legacy_base = Path("/legacy/llama")
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_base)
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_base / "checkpoints")
|
||||
self.assertEqual(RUNTIME_BASE_DIR, legacy_base / "runtime")
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_base / "providers.d")
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, legacy_base / "distributions")
|
||||
|
||||
def test_config_dirs_legacy_directory_exists(self):
|
||||
"""Test config_dirs when legacy ~/.llama directory exists."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home_dir = Path(temp_dir)
|
||||
legacy_dir = home_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch() # Add content
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home_dir
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# Should use legacy directory
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, legacy_dir)
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, legacy_dir / "checkpoints")
|
||||
self.assertEqual(RUNTIME_BASE_DIR, legacy_dir / "runtime")
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR, legacy_dir / "providers.d")
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, legacy_dir / "distributions")
|
||||
|
||||
def test_config_dirs_precedence_order(self):
|
||||
"""Test precedence order: LLAMA_STACK_CONFIG_DIR > legacy directory > XDG."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home_dir = Path(temp_dir)
|
||||
legacy_dir = home_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch()
|
||||
|
||||
# Set both legacy env var and XDG vars
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/priority/path"
|
||||
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home_dir
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
# Legacy env var should take precedence
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/priority/path"))
|
||||
|
||||
def test_config_dirs_all_path_types(self):
|
||||
"""Test that all path objects are of correct type and absolute."""
|
||||
self.clear_env_vars()
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# All should be Path objects
|
||||
paths = [
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
self.assertIsInstance(path, Path, f"Path {path} should be Path object")
|
||||
self.assertTrue(path.is_absolute(), f"Path {path} should be absolute")
|
||||
|
||||
def test_config_dirs_directory_relationships(self):
|
||||
"""Test relationships between different directory paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
)
|
||||
|
||||
# Test parent-child relationships
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR.parent, LLAMA_STACK_CONFIG_DIR)
|
||||
self.assertEqual(DISTRIBS_BASE_DIR.parent, LLAMA_STACK_CONFIG_DIR)
|
||||
|
||||
# Test expected subdirectory names
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR.name, "providers.d")
|
||||
self.assertEqual(DISTRIBS_BASE_DIR.name, "distributions")
|
||||
|
||||
def test_config_dirs_environment_isolation(self):
|
||||
"""Test that config_dirs is properly isolated between tests."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# First import with one set of environment variables
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/first/path"
|
||||
|
||||
# Clear module cache
|
||||
import sys
|
||||
|
||||
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
|
||||
del sys.modules["llama_stack.distribution.utils.config_dirs"]
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_CONFIG
|
||||
|
||||
# Change environment and re-import
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/second/path"
|
||||
|
||||
# Clear module cache again
|
||||
if "llama_stack.distribution.utils.config_dirs" in sys.modules:
|
||||
del sys.modules["llama_stack.distribution.utils.config_dirs"]
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_CONFIG
|
||||
|
||||
# Should get different paths
|
||||
self.assertEqual(FIRST_CONFIG, Path("/first/path"))
|
||||
self.assertEqual(SECOND_CONFIG, Path("/second/path"))
|
||||
|
||||
def test_config_dirs_with_tilde_expansion(self):
|
||||
"""Test config_dirs with tilde in paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "~/custom_llama"
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
# Should expand tilde
|
||||
expected = Path.home() / "custom_llama"
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, expected)
|
||||
|
||||
def test_config_dirs_empty_environment_variables(self):
|
||||
"""Test config_dirs with empty environment variables."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Set empty values
|
||||
os.environ["XDG_CONFIG_HOME"] = ""
|
||||
os.environ["XDG_DATA_HOME"] = ""
|
||||
|
||||
# Mock no legacy directory
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/home/testuser")
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
)
|
||||
|
||||
# Should fall back to defaults
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path("/home/testuser/.config/llama-stack"))
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR, Path("/home/testuser/.local/share/llama-stack/checkpoints"))
|
||||
|
||||
def test_config_dirs_relative_paths(self):
|
||||
"""Test config_dirs with relative paths in environment variables."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.chdir(temp_dir)
|
||||
|
||||
# Use relative path
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "relative/config"
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
# Should be resolved to absolute path
|
||||
self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute())
|
||||
self.assertTrue(str(LLAMA_STACK_CONFIG_DIR).endswith("relative/config"))
|
||||
|
||||
def test_config_dirs_with_spaces_in_paths(self):
|
||||
"""Test config_dirs with spaces in directory paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
path_with_spaces = "/path with spaces/llama config"
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = path_with_spaces
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(path_with_spaces))
|
||||
|
||||
def test_config_dirs_unicode_paths(self):
|
||||
"""Test config_dirs with unicode characters in paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
unicode_path = "/配置/llama-stack"
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = unicode_path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR
|
||||
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, Path(unicode_path))
|
||||
|
||||
def test_config_dirs_compatibility_import(self):
|
||||
"""Test that config_dirs can be imported without errors in various scenarios."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test import with no environment variables
|
||||
try:
|
||||
# If we get here without exception, the import succeeded
|
||||
self.assertTrue(True)
|
||||
except Exception as e:
|
||||
self.fail(f"Import failed: {e}")
|
||||
|
||||
def test_config_dirs_multiple_imports(self):
|
||||
"""Test that multiple imports of config_dirs return consistent results."""
|
||||
self.clear_env_vars()
|
||||
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/consistent/path"
|
||||
|
||||
# First import
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as FIRST_IMPORT
|
||||
|
||||
# Second import (should get cached result)
|
||||
from llama_stack.distribution.utils.config_dirs import LLAMA_STACK_CONFIG_DIR as SECOND_IMPORT
|
||||
|
||||
self.assertEqual(FIRST_IMPORT, SECOND_IMPORT)
|
||||
self.assertIs(FIRST_IMPORT, SECOND_IMPORT) # Should be the same object
|
||||
|
||||
|
||||
class TestConfigDirsIntegration(unittest.TestCase):
|
||||
"""Integration tests for config_dirs with other modules."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
# Clear module cache
|
||||
import sys
|
||||
|
||||
modules_to_clear = ["llama_stack.distribution.utils.config_dirs", "llama_stack.distribution.utils.xdg_utils"]
|
||||
for module in modules_to_clear:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
|
||||
def test_config_dirs_with_model_utils(self):
|
||||
"""Test that config_dirs works correctly with model_utils."""
|
||||
for key in self.original_env:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
# Test that model_local_dir uses the correct base directory
|
||||
model_descriptor = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
expected_path = str(DEFAULT_CHECKPOINT_DIR / model_descriptor.replace(":", "-"))
|
||||
actual_path = model_local_dir(model_descriptor)
|
||||
|
||||
self.assertEqual(actual_path, expected_path)
|
||||
|
||||
def test_config_dirs_consistency_across_modules(self):
|
||||
"""Test that all modules use consistent directory paths."""
|
||||
for key in self.original_env:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
)
|
||||
|
||||
# Paths should be consistent between modules
|
||||
self.assertEqual(LLAMA_STACK_CONFIG_DIR, get_llama_stack_config_dir())
|
||||
self.assertEqual(DEFAULT_CHECKPOINT_DIR.parent, get_llama_stack_data_dir())
|
||||
self.assertEqual(RUNTIME_BASE_DIR.parent, get_llama_stack_state_dir())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
422
tests/unit/test_template_xdg_paths.py
Normal file
422
tests/unit/test_template_xdg_paths.py
Normal file
|
@ -0,0 +1,422 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import yaml
|
||||
|
||||
# Template imports will be tested through file system access
|
||||
|
||||
|
||||
class TestTemplateXDGPaths(unittest.TestCase):
|
||||
"""Test that templates use XDG-compliant paths correctly."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
self.env_vars = [
|
||||
"XDG_CONFIG_HOME",
|
||||
"XDG_DATA_HOME",
|
||||
"XDG_STATE_HOME",
|
||||
"XDG_CACHE_HOME",
|
||||
"LLAMA_STACK_CONFIG_DIR",
|
||||
"SQLITE_STORE_DIR",
|
||||
"FILES_STORAGE_DIR",
|
||||
]
|
||||
|
||||
for key in self.env_vars:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def clear_env_vars(self):
|
||||
"""Clear all relevant environment variables."""
|
||||
for key in self.env_vars:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_ollama_template_run_yaml_xdg_paths(self):
|
||||
"""Test that ollama template's run.yaml uses XDG environment variables."""
|
||||
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
|
||||
|
||||
if not template_path.exists():
|
||||
self.skipTest("Ollama template not found")
|
||||
|
||||
content = template_path.read_text()
|
||||
|
||||
# Check for XDG-compliant environment variable references
|
||||
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
|
||||
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
|
||||
|
||||
# Check that paths use llama-stack directory
|
||||
self.assertIn("llama-stack", content)
|
||||
|
||||
# Check specific path patterns
|
||||
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama", content)
|
||||
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/ollama", content)
|
||||
|
||||
def test_ollama_template_run_yaml_parsing(self):
|
||||
"""Test that ollama template's run.yaml can be parsed correctly."""
|
||||
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
|
||||
|
||||
if not template_path.exists():
|
||||
self.skipTest("Ollama template not found")
|
||||
|
||||
content = template_path.read_text()
|
||||
|
||||
# Replace environment variables with test values for parsing
|
||||
test_content = (
|
||||
content.replace("${env.XDG_STATE_HOME:-~/.local/state}", "/test/state")
|
||||
.replace("${env.XDG_DATA_HOME:-~/.local/share}", "/test/data")
|
||||
.replace(
|
||||
"${env.SQLITE_STORE_DIR:=${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/distributions/ollama}",
|
||||
"/test/state/llama-stack/distributions/ollama",
|
||||
)
|
||||
)
|
||||
|
||||
# Should be valid YAML
|
||||
try:
|
||||
yaml.safe_load(test_content)
|
||||
except yaml.YAMLError as e:
|
||||
self.fail(f"Template YAML is invalid: {e}")
|
||||
|
||||
def test_template_environment_variable_expansion(self):
|
||||
"""Test environment variable expansion in templates."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Set XDG variables
|
||||
os.environ["XDG_STATE_HOME"] = "/custom/state"
|
||||
os.environ["XDG_DATA_HOME"] = "/custom/data"
|
||||
|
||||
# Test pattern that should expand
|
||||
pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
|
||||
expected = "/custom/state/llama-stack/test"
|
||||
|
||||
# Mock environment variable expansion (this would normally be done by the shell)
|
||||
expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", os.environ["XDG_STATE_HOME"])
|
||||
self.assertEqual(expanded, expected)
|
||||
|
||||
def test_template_fallback_values(self):
|
||||
"""Test that templates have correct fallback values."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test fallback pattern
|
||||
pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
|
||||
|
||||
# When environment variable is not set, should use fallback
|
||||
if "XDG_STATE_HOME" not in os.environ:
|
||||
# This is what the shell would do
|
||||
expanded = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state")
|
||||
expected = "~/.local/state/llama-stack/test"
|
||||
self.assertEqual(expanded, expected)
|
||||
|
||||
def test_ollama_template_python_config_xdg(self):
|
||||
"""Test that ollama template's Python config uses XDG-compliant paths."""
|
||||
template_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py"
|
||||
|
||||
if not template_path.exists():
|
||||
self.skipTest("Ollama template Python file not found")
|
||||
|
||||
content = template_path.read_text()
|
||||
|
||||
# Check for XDG-compliant environment variable references
|
||||
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
|
||||
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
|
||||
|
||||
# Check that paths use llama-stack directory
|
||||
self.assertIn("llama-stack", content)
|
||||
|
||||
def test_template_path_consistency(self):
|
||||
"""Test that template paths are consistent across different files."""
|
||||
ollama_yaml_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "run.yaml"
|
||||
ollama_py_path = Path(__file__).parent.parent.parent / "llama_stack" / "templates" / "ollama" / "ollama.py"
|
||||
|
||||
if not ollama_yaml_path.exists() or not ollama_py_path.exists():
|
||||
self.skipTest("Ollama template files not found")
|
||||
|
||||
yaml_content = ollama_yaml_path.read_text()
|
||||
py_content = ollama_py_path.read_text()
|
||||
|
||||
# Both should use the same XDG environment variable patterns
|
||||
xdg_patterns = ["${env.XDG_STATE_HOME:-~/.local/state}", "${env.XDG_DATA_HOME:-~/.local/share}", "llama-stack"]
|
||||
|
||||
for pattern in xdg_patterns:
|
||||
self.assertIn(pattern, yaml_content, f"Pattern {pattern} not found in YAML")
|
||||
self.assertIn(pattern, py_content, f"Pattern {pattern} not found in Python")
|
||||
|
||||
def test_template_no_hardcoded_legacy_paths(self):
|
||||
"""Test that templates don't contain hardcoded legacy paths."""
|
||||
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
if not template_dir.exists():
|
||||
self.skipTest("Templates directory not found")
|
||||
|
||||
# Check various template files
|
||||
for template_path in template_dir.rglob("*.yaml"):
|
||||
content = template_path.read_text()
|
||||
|
||||
# Should not contain hardcoded ~/.llama paths
|
||||
self.assertNotIn("~/.llama", content, f"Found hardcoded ~/.llama in {template_path}")
|
||||
|
||||
# Should not contain hardcoded /tmp paths for persistent data
|
||||
if "db_path" in content or "storage_dir" in content:
|
||||
self.assertNotIn("/tmp", content, f"Found hardcoded /tmp in {template_path}")
|
||||
|
||||
def test_template_environment_variable_format(self):
|
||||
"""Test that templates use correct environment variable format."""
|
||||
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
if not template_dir.exists():
|
||||
self.skipTest("Templates directory not found")
|
||||
|
||||
# Pattern for XDG environment variables with fallbacks
|
||||
xdg_patterns = [
|
||||
"${env.XDG_CONFIG_HOME:-~/.config}",
|
||||
"${env.XDG_DATA_HOME:-~/.local/share}",
|
||||
"${env.XDG_STATE_HOME:-~/.local/state}",
|
||||
"${env.XDG_CACHE_HOME:-~/.cache}",
|
||||
]
|
||||
|
||||
for template_path in template_dir.rglob("*.yaml"):
|
||||
content = template_path.read_text()
|
||||
|
||||
# If XDG variables are used, they should have proper fallbacks
|
||||
for pattern in xdg_patterns:
|
||||
base_var = pattern.split(":-")[0] + "}"
|
||||
if base_var in content:
|
||||
self.assertIn(pattern, content, f"XDG variable without fallback in {template_path}")
|
||||
|
||||
def test_template_sqlite_store_dir_xdg(self):
|
||||
"""Test that SQLITE_STORE_DIR uses XDG-compliant fallback."""
|
||||
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
if not template_dir.exists():
|
||||
self.skipTest("Templates directory not found")
|
||||
|
||||
for template_path in template_dir.rglob("*.yaml"):
|
||||
content = template_path.read_text()
|
||||
|
||||
if "SQLITE_STORE_DIR" in content:
|
||||
# Should use XDG fallback pattern
|
||||
self.assertIn("${env.XDG_STATE_HOME:-~/.local/state}", content)
|
||||
self.assertIn("llama-stack", content)
|
||||
|
||||
def test_template_files_storage_dir_xdg(self):
|
||||
"""Test that FILES_STORAGE_DIR uses XDG-compliant fallback."""
|
||||
template_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
if not template_dir.exists():
|
||||
self.skipTest("Templates directory not found")
|
||||
|
||||
for template_path in template_dir.rglob("*.yaml"):
|
||||
content = template_path.read_text()
|
||||
|
||||
if "FILES_STORAGE_DIR" in content:
|
||||
# Should use XDG fallback pattern
|
||||
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}", content)
|
||||
self.assertIn("llama-stack", content)
|
||||
|
||||
|
||||
class TestTemplateCodeGeneration(unittest.TestCase):
|
||||
"""Test template code generation with XDG paths."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def test_provider_codegen_xdg_paths(self):
|
||||
"""Test that provider code generation uses XDG-compliant paths."""
|
||||
codegen_path = Path(__file__).parent.parent.parent / "scripts" / "provider_codegen.py"
|
||||
|
||||
if not codegen_path.exists():
|
||||
self.skipTest("Provider codegen script not found")
|
||||
|
||||
content = codegen_path.read_text()
|
||||
|
||||
# Should use XDG-compliant path in documentation
|
||||
self.assertIn("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack", content)
|
||||
|
||||
# Should not use hardcoded ~/.llama paths
|
||||
self.assertNotIn("~/.llama/dummy", content)
|
||||
|
||||
def test_template_sample_config_paths(self):
|
||||
"""Test that template sample configs use XDG-compliant paths."""
|
||||
# This test checks that when templates generate sample configs,
|
||||
# they use XDG-compliant paths
|
||||
|
||||
# Mock a template that generates sample config
|
||||
with patch("llama_stack.templates.template.Template") as mock_template:
|
||||
mock_instance = MagicMock()
|
||||
mock_template.return_value = mock_instance
|
||||
|
||||
# Mock sample config generation
|
||||
def mock_sample_config(distro_dir):
|
||||
# Should use XDG-compliant path structure
|
||||
self.assertIn("llama-stack", distro_dir)
|
||||
return {"config": "test"}
|
||||
|
||||
mock_instance.sample_run_config = mock_sample_config
|
||||
|
||||
# Test sample config generation
|
||||
template = mock_template()
|
||||
template.sample_run_config("${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/test")
|
||||
|
||||
def test_template_path_substitution(self):
|
||||
"""Test that template path substitution works correctly."""
|
||||
# Test path substitution in template generation
|
||||
|
||||
original_path = "~/.llama/distributions/test"
|
||||
|
||||
# Should be converted to XDG-compliant path
|
||||
xdg_path = original_path.replace("~/.llama", "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack")
|
||||
expected = "${env.XDG_DATA_HOME:-~/.local/share}/llama-stack/distributions/test"
|
||||
|
||||
self.assertEqual(xdg_path, expected)
|
||||
|
||||
def test_template_environment_variable_precedence(self):
|
||||
"""Test environment variable precedence in templates."""
|
||||
# Test that custom XDG variables take precedence over defaults
|
||||
|
||||
test_cases = [
|
||||
{
|
||||
"env": {"XDG_STATE_HOME": "/custom/state"},
|
||||
"pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test",
|
||||
"expected": "/custom/state/llama-stack/test",
|
||||
},
|
||||
{
|
||||
"env": {}, # No XDG variable set
|
||||
"pattern": "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test",
|
||||
"expected": "~/.local/state/llama-stack/test",
|
||||
},
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
# Clear environment
|
||||
for key in ["XDG_STATE_HOME", "XDG_DATA_HOME", "XDG_CONFIG_HOME"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Set test environment
|
||||
for key, value in case["env"].items():
|
||||
os.environ[key] = value
|
||||
|
||||
# Simulate shell variable expansion
|
||||
pattern = case["pattern"]
|
||||
for key, value in case["env"].items():
|
||||
var_pattern = f"${{env.{key}:-"
|
||||
if var_pattern in pattern:
|
||||
# Replace with actual value
|
||||
pattern = pattern.replace(f"${{env.{key}:-~/.local/state}}", value)
|
||||
|
||||
# If no replacement happened, use fallback
|
||||
if "${env.XDG_STATE_HOME:-~/.local/state}" in pattern:
|
||||
pattern = pattern.replace("${env.XDG_STATE_HOME:-~/.local/state}", "~/.local/state")
|
||||
|
||||
self.assertEqual(pattern, case["expected"])
|
||||
|
||||
|
||||
class TestTemplateIntegration(unittest.TestCase):
|
||||
"""Integration tests for templates with XDG compliance."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def test_template_with_xdg_environment(self):
|
||||
"""Test template behavior with XDG environment variables set."""
|
||||
# Clear environment
|
||||
for key in self.original_env:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Set custom XDG variables
|
||||
os.environ["XDG_CONFIG_HOME"] = "/custom/config"
|
||||
os.environ["XDG_DATA_HOME"] = "/custom/data"
|
||||
os.environ["XDG_STATE_HOME"] = "/custom/state"
|
||||
|
||||
# Test that template paths would resolve correctly
|
||||
# (This is a conceptual test since actual shell expansion happens at runtime)
|
||||
|
||||
template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
|
||||
|
||||
# In a real shell, this would expand to:
|
||||
|
||||
# Verify the pattern structure is correct
|
||||
self.assertIn("XDG_STATE_HOME", template_pattern)
|
||||
self.assertIn("llama-stack", template_pattern)
|
||||
self.assertIn("~/.local/state", template_pattern) # fallback
|
||||
|
||||
def test_template_with_no_xdg_environment(self):
|
||||
"""Test template behavior with no XDG environment variables."""
|
||||
# Clear all XDG environment variables
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"]:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# Test that templates would use fallback values
|
||||
template_pattern = "${env.XDG_STATE_HOME:-~/.local/state}/llama-stack/test"
|
||||
|
||||
# In a real shell with no XDG_STATE_HOME, this would expand to:
|
||||
|
||||
# Verify the pattern structure includes fallback
|
||||
self.assertIn(":-~/.local/state", template_pattern)
|
||||
|
||||
def test_template_consistency_across_providers(self):
|
||||
"""Test that all template providers use consistent XDG patterns."""
|
||||
templates_dir = Path(__file__).parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
if not templates_dir.exists():
|
||||
self.skipTest("Templates directory not found")
|
||||
|
||||
# Expected XDG patterns that should be consistent across templates
|
||||
|
||||
# Check a few different provider templates
|
||||
provider_templates = []
|
||||
for provider_dir in templates_dir.iterdir():
|
||||
if provider_dir.is_dir() and not provider_dir.name.startswith("."):
|
||||
run_yaml = provider_dir / "run.yaml"
|
||||
if run_yaml.exists():
|
||||
provider_templates.append(run_yaml)
|
||||
|
||||
if not provider_templates:
|
||||
self.skipTest("No provider templates found")
|
||||
|
||||
# Check that templates use consistent patterns
|
||||
for template_path in provider_templates[:3]: # Check first 3 templates
|
||||
content = template_path.read_text()
|
||||
|
||||
# Should use llama-stack in paths
|
||||
if any(xdg_var in content for xdg_var in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME"]):
|
||||
self.assertIn("llama-stack", content, f"Template {template_path} uses XDG but not llama-stack")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
610
tests/unit/test_xdg_compliance.py
Normal file
610
tests/unit/test_xdg_compliance.py
Normal file
|
@ -0,0 +1,610 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
ensure_directory_exists,
|
||||
get_llama_stack_cache_dir,
|
||||
get_llama_stack_config_dir,
|
||||
get_llama_stack_data_dir,
|
||||
get_llama_stack_state_dir,
|
||||
get_xdg_cache_home,
|
||||
get_xdg_compliant_path,
|
||||
get_xdg_config_home,
|
||||
get_xdg_data_home,
|
||||
get_xdg_state_home,
|
||||
migrate_legacy_directory,
|
||||
)
|
||||
|
||||
|
||||
class TestXDGCompliance(unittest.TestCase):
|
||||
"""Comprehensive test suite for XDG Base Directory Specification compliance."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
# Store original environment variables
|
||||
self.original_env = {}
|
||||
self.xdg_vars = ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME"]
|
||||
self.llama_vars = ["LLAMA_STACK_CONFIG_DIR", "SQLITE_STORE_DIR", "FILES_STORAGE_DIR"]
|
||||
|
||||
for key in self.xdg_vars + self.llama_vars:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
# Restore original environment variables
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def clear_env_vars(self, vars_to_clear=None):
|
||||
"""Helper to clear environment variables."""
|
||||
if vars_to_clear is None:
|
||||
vars_to_clear = self.xdg_vars + self.llama_vars
|
||||
|
||||
for key in vars_to_clear:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_xdg_defaults(self):
|
||||
"""Test that XDG directories use correct defaults when no env vars are set."""
|
||||
self.clear_env_vars()
|
||||
home = Path.home()
|
||||
|
||||
self.assertEqual(get_xdg_config_home(), home / ".config")
|
||||
self.assertEqual(get_xdg_data_home(), home / ".local" / "share")
|
||||
self.assertEqual(get_xdg_cache_home(), home / ".cache")
|
||||
self.assertEqual(get_xdg_state_home(), home / ".local" / "state")
|
||||
|
||||
def test_xdg_custom_paths(self):
|
||||
"""Test that custom XDG paths are respected."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
|
||||
os.environ["XDG_DATA_HOME"] = str(temp_path / "data")
|
||||
os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache")
|
||||
os.environ["XDG_STATE_HOME"] = str(temp_path / "state")
|
||||
|
||||
self.assertEqual(get_xdg_config_home(), temp_path / "config")
|
||||
self.assertEqual(get_xdg_data_home(), temp_path / "data")
|
||||
self.assertEqual(get_xdg_cache_home(), temp_path / "cache")
|
||||
self.assertEqual(get_xdg_state_home(), temp_path / "state")
|
||||
|
||||
def test_xdg_paths_with_tilde(self):
|
||||
"""Test XDG paths that use tilde expansion."""
|
||||
os.environ["XDG_CONFIG_HOME"] = "~/custom_config"
|
||||
os.environ["XDG_DATA_HOME"] = "~/custom_data"
|
||||
|
||||
home = Path.home()
|
||||
self.assertEqual(get_xdg_config_home(), home / "custom_config")
|
||||
self.assertEqual(get_xdg_data_home(), home / "custom_data")
|
||||
|
||||
def test_xdg_paths_relative(self):
|
||||
"""Test XDG paths with relative paths get resolved."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.chdir(temp_dir)
|
||||
os.environ["XDG_CONFIG_HOME"] = "relative_config"
|
||||
|
||||
# Should resolve relative to current directory
|
||||
result = get_xdg_config_home()
|
||||
self.assertTrue(result.is_absolute())
|
||||
self.assertTrue(str(result).endswith("relative_config"))
|
||||
|
||||
def test_llama_stack_directories_new_installation(self):
|
||||
"""Test llama-stack directories for new installations (no legacy directory)."""
|
||||
self.clear_env_vars()
|
||||
home = Path.home()
|
||||
|
||||
# Mock that ~/.llama doesn't exist
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_cache_dir(), home / ".cache" / "llama-stack")
|
||||
|
||||
def test_llama_stack_directories_with_custom_xdg(self):
|
||||
"""Test llama-stack directories with custom XDG paths."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
|
||||
os.environ["XDG_DATA_HOME"] = str(temp_path / "data")
|
||||
os.environ["XDG_STATE_HOME"] = str(temp_path / "state")
|
||||
os.environ["XDG_CACHE_HOME"] = str(temp_path / "cache")
|
||||
|
||||
# Mock that ~/.llama doesn't exist
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_data_dir(), temp_path / "data" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_state_dir(), temp_path / "state" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_cache_dir(), temp_path / "cache" / "llama-stack")
|
||||
|
||||
def test_legacy_environment_variable_precedence(self):
|
||||
"""Test that LLAMA_STACK_CONFIG_DIR takes highest precedence."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
legacy_path = Path(temp_dir) / "legacy"
|
||||
xdg_path = Path(temp_dir) / "xdg"
|
||||
|
||||
# Set both legacy and XDG variables
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = str(legacy_path)
|
||||
os.environ["XDG_CONFIG_HOME"] = str(xdg_path / "config")
|
||||
os.environ["XDG_DATA_HOME"] = str(xdg_path / "data")
|
||||
os.environ["XDG_STATE_HOME"] = str(xdg_path / "state")
|
||||
|
||||
# Legacy should take precedence for all directory types
|
||||
self.assertEqual(get_llama_stack_config_dir(), legacy_path)
|
||||
self.assertEqual(get_llama_stack_data_dir(), legacy_path)
|
||||
self.assertEqual(get_llama_stack_state_dir(), legacy_path)
|
||||
self.assertEqual(get_llama_stack_cache_dir(), legacy_path)
|
||||
|
||||
def test_legacy_directory_exists_and_has_content(self):
|
||||
"""Test that existing ~/.llama directory with content is used."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home = Path(temp_dir)
|
||||
legacy_llama = home / ".llama"
|
||||
legacy_llama.mkdir()
|
||||
|
||||
# Create some content to simulate existing data
|
||||
(legacy_llama / "test_file").touch()
|
||||
(legacy_llama / "distributions").mkdir()
|
||||
|
||||
# Clear environment variables
|
||||
self.clear_env_vars()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
|
||||
self.assertEqual(get_llama_stack_config_dir(), legacy_llama)
|
||||
self.assertEqual(get_llama_stack_data_dir(), legacy_llama)
|
||||
self.assertEqual(get_llama_stack_state_dir(), legacy_llama)
|
||||
|
||||
def test_legacy_directory_exists_but_empty(self):
|
||||
"""Test that empty ~/.llama directory is ignored in favor of XDG."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home = Path(temp_dir)
|
||||
legacy_llama = home / ".llama"
|
||||
legacy_llama.mkdir()
|
||||
# Don't add any content - directory is empty
|
||||
|
||||
self.clear_env_vars()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
|
||||
# Should use XDG paths since legacy directory is empty
|
||||
self.assertEqual(get_llama_stack_config_dir(), home / ".config" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
|
||||
|
||||
def test_xdg_compliant_path_function(self):
|
||||
"""Test the get_xdg_compliant_path utility function."""
|
||||
self.clear_env_vars()
|
||||
home = Path.home()
|
||||
|
||||
# Mock that ~/.llama doesn't exist
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
self.assertEqual(get_xdg_compliant_path("config"), home / ".config" / "llama-stack")
|
||||
self.assertEqual(
|
||||
get_xdg_compliant_path("data", "models"), home / ".local" / "share" / "llama-stack" / "models"
|
||||
)
|
||||
self.assertEqual(
|
||||
get_xdg_compliant_path("state", "runtime"), home / ".local" / "state" / "llama-stack" / "runtime"
|
||||
)
|
||||
self.assertEqual(get_xdg_compliant_path("cache", "temp"), home / ".cache" / "llama-stack" / "temp")
|
||||
|
||||
def test_xdg_compliant_path_invalid_type(self):
|
||||
"""Test that invalid path types raise ValueError."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
get_xdg_compliant_path("invalid_type")
|
||||
|
||||
self.assertIn("Unknown path type", str(context.exception))
|
||||
self.assertIn("invalid_type", str(context.exception))
|
||||
|
||||
def test_xdg_compliant_path_with_subdirectory(self):
|
||||
"""Test get_xdg_compliant_path with various subdirectories."""
|
||||
self.clear_env_vars()
|
||||
home = Path.home()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Test nested subdirectories
|
||||
self.assertEqual(
|
||||
get_xdg_compliant_path("data", "models/checkpoints"),
|
||||
home / ".local" / "share" / "llama-stack" / "models/checkpoints",
|
||||
)
|
||||
|
||||
# Test with Path object
|
||||
self.assertEqual(
|
||||
get_xdg_compliant_path("config", str(Path("distributions") / "ollama")),
|
||||
home / ".config" / "llama-stack" / "distributions" / "ollama",
|
||||
)
|
||||
|
||||
def test_ensure_directory_exists(self):
|
||||
"""Test the ensure_directory_exists utility function."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_path = Path(temp_dir) / "nested" / "directory" / "structure"
|
||||
|
||||
# Directory shouldn't exist initially
|
||||
self.assertFalse(test_path.exists())
|
||||
|
||||
# Create it
|
||||
ensure_directory_exists(test_path)
|
||||
|
||||
# Should exist now
|
||||
self.assertTrue(test_path.exists())
|
||||
self.assertTrue(test_path.is_dir())
|
||||
|
||||
def test_ensure_directory_exists_already_exists(self):
|
||||
"""Test ensure_directory_exists when directory already exists."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
test_path = Path(temp_dir) / "existing"
|
||||
test_path.mkdir()
|
||||
|
||||
# Should not raise an error
|
||||
ensure_directory_exists(test_path)
|
||||
self.assertTrue(test_path.exists())
|
||||
|
||||
def test_config_dirs_import_and_types(self):
|
||||
"""Test that the config_dirs module imports correctly and has proper types."""
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
RUNTIME_BASE_DIR,
|
||||
)
|
||||
|
||||
# All should be Path objects
|
||||
self.assertIsInstance(LLAMA_STACK_CONFIG_DIR, Path)
|
||||
self.assertIsInstance(DEFAULT_CHECKPOINT_DIR, Path)
|
||||
self.assertIsInstance(RUNTIME_BASE_DIR, Path)
|
||||
self.assertIsInstance(EXTERNAL_PROVIDERS_DIR, Path)
|
||||
self.assertIsInstance(DISTRIBS_BASE_DIR, Path)
|
||||
|
||||
# All should be absolute paths
|
||||
self.assertTrue(LLAMA_STACK_CONFIG_DIR.is_absolute())
|
||||
self.assertTrue(DEFAULT_CHECKPOINT_DIR.is_absolute())
|
||||
self.assertTrue(RUNTIME_BASE_DIR.is_absolute())
|
||||
self.assertTrue(EXTERNAL_PROVIDERS_DIR.is_absolute())
|
||||
self.assertTrue(DISTRIBS_BASE_DIR.is_absolute())
|
||||
|
||||
def test_config_dirs_proper_structure(self):
|
||||
"""Test that config_dirs uses proper XDG structure."""
|
||||
from llama_stack.distribution.utils.config_dirs import (
|
||||
DISTRIBS_BASE_DIR,
|
||||
EXTERNAL_PROVIDERS_DIR,
|
||||
LLAMA_STACK_CONFIG_DIR,
|
||||
)
|
||||
|
||||
# Check that paths contain expected components
|
||||
config_str = str(LLAMA_STACK_CONFIG_DIR)
|
||||
self.assertTrue(
|
||||
"llama-stack" in config_str or ".llama" in config_str,
|
||||
f"Config dir should contain 'llama-stack' or '.llama': {config_str}",
|
||||
)
|
||||
|
||||
# Test relationships between directories
|
||||
self.assertEqual(DISTRIBS_BASE_DIR, LLAMA_STACK_CONFIG_DIR / "distributions")
|
||||
self.assertEqual(EXTERNAL_PROVIDERS_DIR, LLAMA_STACK_CONFIG_DIR / "providers.d")
|
||||
|
||||
def test_environment_variable_combinations(self):
|
||||
"""Test various combinations of environment variables."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Test partial XDG variables
|
||||
os.environ["XDG_CONFIG_HOME"] = str(temp_path / "config")
|
||||
# Leave others as default
|
||||
self.clear_env_vars(["XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME"])
|
||||
|
||||
home = Path.home()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
self.assertEqual(get_llama_stack_config_dir(), temp_path / "config" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_data_dir(), home / ".local" / "share" / "llama-stack")
|
||||
self.assertEqual(get_llama_stack_state_dir(), home / ".local" / "state" / "llama-stack")
|
||||
|
||||
def test_migrate_legacy_directory_no_legacy(self):
|
||||
"""Test migration when no legacy directory exists."""
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = Path("/fake/home")
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# Should return True (success) when no migration needed
|
||||
result = migrate_legacy_directory()
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_migrate_legacy_directory_exists(self):
|
||||
"""Test migration message when legacy directory exists."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home = Path(temp_dir)
|
||||
legacy_llama = home / ".llama"
|
||||
legacy_llama.mkdir()
|
||||
(legacy_llama / "test_file").touch()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
with patch("builtins.print") as mock_print:
|
||||
result = migrate_legacy_directory()
|
||||
self.assertTrue(result)
|
||||
|
||||
# Check that migration information was printed
|
||||
print_calls = [call[0][0] for call in mock_print.call_args_list]
|
||||
self.assertTrue(any("Found legacy directory" in call for call in print_calls))
|
||||
self.assertTrue(any("Consider migrating" in call for call in print_calls))
|
||||
|
||||
def test_path_consistency_across_functions(self):
|
||||
"""Test that all path functions return consistent results."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
# All config-related functions should return the same base
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
config_path = get_xdg_compliant_path("config")
|
||||
self.assertEqual(config_dir, config_path)
|
||||
|
||||
# All data-related functions should return the same base
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
data_path = get_xdg_compliant_path("data")
|
||||
self.assertEqual(data_dir, data_path)
|
||||
|
||||
def test_unicode_and_special_characters(self):
|
||||
"""Test XDG paths with unicode and special characters."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Test with unicode characters
|
||||
unicode_path = Path(temp_dir) / "配置" / "llama-stack"
|
||||
os.environ["XDG_CONFIG_HOME"] = str(unicode_path.parent)
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, unicode_path.parent)
|
||||
|
||||
# Test spaces in paths
|
||||
space_path = Path(temp_dir) / "my config"
|
||||
os.environ["XDG_CONFIG_HOME"] = str(space_path)
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, space_path)
|
||||
|
||||
def test_concurrent_access_safety(self):
|
||||
"""Test that XDG functions are safe for concurrent access."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def worker():
|
||||
try:
|
||||
# Simulate concurrent access
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
time.sleep(0.01) # Small delay to increase chance of race conditions
|
||||
data_dir = get_llama_stack_data_dir()
|
||||
results.append((config_dir, data_dir))
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for _ in range(10):
|
||||
t = threading.Thread(target=worker)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Check results
|
||||
self.assertEqual(len(errors), 0, f"Concurrent access errors: {errors}")
|
||||
self.assertEqual(len(results), 10)
|
||||
|
||||
# All results should be identical
|
||||
first_result = results[0]
|
||||
for result in results[1:]:
|
||||
self.assertEqual(result, first_result)
|
||||
|
||||
def test_symlink_handling(self):
|
||||
"""Test XDG path handling with symbolic links."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Create actual directory
|
||||
actual_dir = temp_path / "actual_config"
|
||||
actual_dir.mkdir()
|
||||
|
||||
# Create symlink
|
||||
symlink_dir = temp_path / "symlinked_config"
|
||||
symlink_dir.symlink_to(actual_dir)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(symlink_dir)
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, symlink_dir)
|
||||
|
||||
# Should resolve to actual path when needed
|
||||
self.assertTrue(result.exists())
|
||||
|
||||
def test_readonly_directory_handling(self):
|
||||
"""Test behavior when XDG directories are read-only."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
readonly_dir = temp_path / "readonly"
|
||||
readonly_dir.mkdir()
|
||||
|
||||
# Make directory read-only
|
||||
readonly_dir.chmod(0o444)
|
||||
|
||||
try:
|
||||
os.environ["XDG_CONFIG_HOME"] = str(readonly_dir)
|
||||
|
||||
# Should still return the path even if read-only
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, readonly_dir)
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
readonly_dir.chmod(0o755)
|
||||
|
||||
def test_nonexistent_parent_directory(self):
|
||||
"""Test XDG paths with non-existent parent directories."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use a path with non-existent parents
|
||||
nonexistent_path = Path(temp_dir) / "does" / "not" / "exist" / "config"
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(nonexistent_path)
|
||||
|
||||
# Should return the path even if it doesn't exist
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, nonexistent_path)
|
||||
|
||||
def test_env_var_expansion(self):
|
||||
"""Test environment variable expansion in XDG paths."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.environ["TEST_BASE"] = temp_dir
|
||||
os.environ["XDG_CONFIG_HOME"] = "$TEST_BASE/config"
|
||||
|
||||
# Path expansion should work
|
||||
result = get_xdg_config_home()
|
||||
expected = Path(temp_dir) / "config"
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
class TestXDGEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and error conditions for XDG compliance."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_CACHE_HOME", "XDG_STATE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def test_empty_environment_variables(self):
|
||||
"""Test behavior with empty environment variables."""
|
||||
# Set empty values
|
||||
os.environ["XDG_CONFIG_HOME"] = ""
|
||||
os.environ["XDG_DATA_HOME"] = ""
|
||||
|
||||
# Should fall back to defaults
|
||||
home = Path.home()
|
||||
self.assertEqual(get_xdg_config_home(), home / ".config")
|
||||
self.assertEqual(get_xdg_data_home(), home / ".local" / "share")
|
||||
|
||||
def test_whitespace_only_environment_variables(self):
|
||||
"""Test behavior with whitespace-only environment variables."""
|
||||
os.environ["XDG_CONFIG_HOME"] = " "
|
||||
os.environ["XDG_DATA_HOME"] = "\t\n"
|
||||
|
||||
# Should handle whitespace gracefully
|
||||
result_config = get_xdg_config_home()
|
||||
result_data = get_xdg_data_home()
|
||||
|
||||
# Results should be valid Path objects
|
||||
self.assertIsInstance(result_config, Path)
|
||||
self.assertIsInstance(result_data, Path)
|
||||
|
||||
def test_very_long_paths(self):
|
||||
"""Test behavior with very long directory paths."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a very long path
|
||||
long_path_parts = ["very_long_directory_name_" + str(i) for i in range(20)]
|
||||
long_path = Path(temp_dir)
|
||||
for part in long_path_parts:
|
||||
long_path = long_path / part
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(long_path)
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, long_path)
|
||||
|
||||
def test_circular_symlinks(self):
|
||||
"""Test handling of circular symbolic links."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Create circular symlinks
|
||||
link1 = temp_path / "link1"
|
||||
link2 = temp_path / "link2"
|
||||
|
||||
try:
|
||||
link1.symlink_to(link2)
|
||||
link2.symlink_to(link1)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(link1)
|
||||
|
||||
# Should handle circular symlinks gracefully
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, link1)
|
||||
|
||||
except (OSError, NotImplementedError):
|
||||
# Some systems don't support circular symlinks
|
||||
self.skipTest("System doesn't support circular symlinks")
|
||||
|
||||
def test_permission_denied_scenarios(self):
|
||||
"""Test scenarios where permission is denied."""
|
||||
# This test is platform-specific and may not work on all systems
|
||||
try:
|
||||
# Try to use a system directory that typically requires root
|
||||
os.environ["XDG_CONFIG_HOME"] = "/root/.config"
|
||||
|
||||
# Should still return the path even if we can't access it
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path("/root/.config"))
|
||||
|
||||
except Exception:
|
||||
# If this fails, it's not critical for the XDG implementation
|
||||
pass
|
||||
|
||||
def test_network_paths(self):
|
||||
"""Test XDG paths with network/UNC paths (Windows-style)."""
|
||||
# Test UNC path (though this may not work on non-Windows systems)
|
||||
network_path = "//server/share/config"
|
||||
os.environ["XDG_CONFIG_HOME"] = network_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
# Should handle network paths gracefully
|
||||
self.assertIsInstance(result, Path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
522
tests/unit/test_xdg_edge_cases.py
Normal file
522
tests/unit/test_xdg_edge_cases.py
Normal file
|
@ -0,0 +1,522 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.distribution.utils.xdg_utils import (
|
||||
ensure_directory_exists,
|
||||
get_llama_stack_config_dir,
|
||||
get_xdg_config_home,
|
||||
migrate_legacy_directory,
|
||||
)
|
||||
|
||||
|
||||
class TestXDGEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and error conditions for XDG compliance."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.original_env = {}
|
||||
for key in ["XDG_CONFIG_HOME", "XDG_DATA_HOME", "XDG_STATE_HOME", "XDG_CACHE_HOME", "LLAMA_STACK_CONFIG_DIR"]:
|
||||
self.original_env[key] = os.environ.get(key)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment."""
|
||||
for key, value in self.original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
def clear_env_vars(self):
|
||||
"""Clear all XDG environment variables."""
|
||||
for key in self.original_env:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_very_long_paths(self):
|
||||
"""Test XDG functions with very long directory paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Create a very long path (close to filesystem limits)
|
||||
long_components = ["very_long_directory_name_" + str(i) for i in range(50)]
|
||||
long_path = "/tmp/" + "/".join(long_components)
|
||||
|
||||
# Test with very long XDG paths
|
||||
os.environ["XDG_CONFIG_HOME"] = long_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(long_path))
|
||||
|
||||
# Should handle long paths in llama-stack functions
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
self.assertEqual(config_dir, Path(long_path) / "llama-stack")
|
||||
|
||||
def test_paths_with_special_characters(self):
|
||||
"""Test XDG functions with special characters in paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test various special characters
|
||||
special_chars = [
|
||||
"path with spaces",
|
||||
"path-with-hyphens",
|
||||
"path_with_underscores",
|
||||
"path.with.dots",
|
||||
"path@with@symbols",
|
||||
"path+with+plus",
|
||||
"path&with&ersand",
|
||||
"path(with)parentheses",
|
||||
]
|
||||
|
||||
for special_path in special_chars:
|
||||
with self.subTest(path=special_path):
|
||||
test_path = f"/tmp/{special_path}"
|
||||
os.environ["XDG_CONFIG_HOME"] = test_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(test_path))
|
||||
|
||||
def test_unicode_paths(self):
|
||||
"""Test XDG functions with unicode characters in paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
unicode_paths = [
|
||||
"/配置/llama-stack", # Chinese
|
||||
"/конфигурация/llama-stack", # Russian
|
||||
"/構成/llama-stack", # Japanese
|
||||
"/구성/llama-stack", # Korean
|
||||
"/تكوين/llama-stack", # Arabic
|
||||
"/configuración/llama-stack", # Spanish with accents
|
||||
"/配置📁/llama-stack", # With emoji
|
||||
]
|
||||
|
||||
for unicode_path in unicode_paths:
|
||||
with self.subTest(path=unicode_path):
|
||||
os.environ["XDG_CONFIG_HOME"] = unicode_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(unicode_path))
|
||||
|
||||
def test_network_paths(self):
|
||||
"""Test XDG functions with network/UNC paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
if platform.system() == "Windows":
|
||||
# Test Windows UNC paths
|
||||
unc_paths = [
|
||||
"\\\\server\\share\\config",
|
||||
"\\\\server.domain.com\\share\\config",
|
||||
"\\\\192.168.1.100\\config",
|
||||
]
|
||||
|
||||
for unc_path in unc_paths:
|
||||
with self.subTest(path=unc_path):
|
||||
os.environ["XDG_CONFIG_HOME"] = unc_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(unc_path))
|
||||
else:
|
||||
# Test network mount paths on Unix-like systems
|
||||
network_paths = [
|
||||
"/mnt/nfs/config",
|
||||
"/net/server/config",
|
||||
"/media/network/config",
|
||||
]
|
||||
|
||||
for network_path in network_paths:
|
||||
with self.subTest(path=network_path):
|
||||
os.environ["XDG_CONFIG_HOME"] = network_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(network_path))
|
||||
|
||||
def test_nonexistent_paths(self):
|
||||
"""Test XDG functions with non-existent paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
nonexistent_path = "/this/path/does/not/exist/config"
|
||||
os.environ["XDG_CONFIG_HOME"] = nonexistent_path
|
||||
|
||||
# Should return the path even if it doesn't exist
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(nonexistent_path))
|
||||
|
||||
# Should work with llama-stack functions too
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.exists") as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
self.assertEqual(config_dir, Path(nonexistent_path) / "llama-stack")
|
||||
|
||||
def test_circular_symlinks(self):
|
||||
"""Test XDG functions with circular symbolic links."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Create circular symlinks
|
||||
link1 = temp_path / "link1"
|
||||
link2 = temp_path / "link2"
|
||||
|
||||
try:
|
||||
link1.symlink_to(link2)
|
||||
link2.symlink_to(link1)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(link1)
|
||||
|
||||
# Should handle circular symlinks gracefully
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, link1)
|
||||
|
||||
except (OSError, NotImplementedError):
|
||||
# Some systems don't support circular symlinks
|
||||
self.skipTest("System doesn't support circular symlinks")
|
||||
|
||||
def test_broken_symlinks(self):
|
||||
"""Test XDG functions with broken symbolic links."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Create broken symlink
|
||||
target = temp_path / "nonexistent_target"
|
||||
link = temp_path / "broken_link"
|
||||
|
||||
try:
|
||||
link.symlink_to(target)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = str(link)
|
||||
|
||||
# Should handle broken symlinks gracefully
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, link)
|
||||
|
||||
except (OSError, NotImplementedError):
|
||||
# Some systems might not support this
|
||||
self.skipTest("System doesn't support broken symlinks")
|
||||
|
||||
def test_readonly_directories(self):
|
||||
"""Test XDG functions with read-only directories."""
|
||||
self.clear_env_vars()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
readonly_dir = temp_path / "readonly"
|
||||
readonly_dir.mkdir()
|
||||
|
||||
# Make directory read-only
|
||||
readonly_dir.chmod(0o444)
|
||||
|
||||
try:
|
||||
os.environ["XDG_CONFIG_HOME"] = str(readonly_dir)
|
||||
|
||||
# Should still return the path
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, readonly_dir)
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
readonly_dir.chmod(0o755)
|
||||
|
||||
def test_permission_denied_access(self):
|
||||
"""Test XDG functions when permission is denied."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# This test is platform-specific
|
||||
if platform.system() != "Windows":
|
||||
# Try to use a system directory that typically requires root
|
||||
restricted_paths = [
|
||||
"/root/.config",
|
||||
"/etc/config",
|
||||
"/var/root/config",
|
||||
]
|
||||
|
||||
for restricted_path in restricted_paths:
|
||||
with self.subTest(path=restricted_path):
|
||||
os.environ["XDG_CONFIG_HOME"] = restricted_path
|
||||
|
||||
# Should still return the path even if we can't access it
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(restricted_path))
|
||||
|
||||
def test_environment_variable_injection(self):
|
||||
"""Test XDG functions with environment variable injection attempts."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test potential injection attempts
|
||||
injection_attempts = [
|
||||
"/tmp/config; rm -rf /",
|
||||
"/tmp/config && echo 'injected'",
|
||||
"/tmp/config | cat /etc/passwd",
|
||||
"/tmp/config`whoami`",
|
||||
"/tmp/config$(whoami)",
|
||||
"/tmp/config\necho 'newline'",
|
||||
]
|
||||
|
||||
for injection_attempt in injection_attempts:
|
||||
with self.subTest(attempt=injection_attempt):
|
||||
os.environ["XDG_CONFIG_HOME"] = injection_attempt
|
||||
|
||||
# Should treat as literal path, not execute
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(injection_attempt))
|
||||
|
||||
def test_extremely_nested_paths(self):
|
||||
"""Test XDG functions with extremely nested directory structures."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Create deeply nested path
|
||||
nested_parts = ["level" + str(i) for i in range(100)]
|
||||
nested_path = "/tmp/" + "/".join(nested_parts)
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = nested_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(nested_path))
|
||||
|
||||
def test_empty_and_whitespace_paths(self):
|
||||
"""Test XDG functions with empty and whitespace-only paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
empty_values = [
|
||||
"",
|
||||
" ",
|
||||
"\t",
|
||||
"\n",
|
||||
"\r\n",
|
||||
" \t \n ",
|
||||
]
|
||||
|
||||
for empty_value in empty_values:
|
||||
with self.subTest(value=repr(empty_value)):
|
||||
os.environ["XDG_CONFIG_HOME"] = empty_value
|
||||
|
||||
# Should fall back to default
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path.home() / ".config")
|
||||
|
||||
def test_path_with_null_bytes(self):
|
||||
"""Test XDG functions with null bytes in paths."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test path with null byte
|
||||
null_path = "/tmp/config\x00/test"
|
||||
os.environ["XDG_CONFIG_HOME"] = null_path
|
||||
|
||||
# Should handle null bytes (Path will likely raise an error, which is expected)
|
||||
try:
|
||||
result = get_xdg_config_home()
|
||||
# If it doesn't raise an error, check the result
|
||||
self.assertIsInstance(result, Path)
|
||||
except (ValueError, OSError):
|
||||
# This is expected behavior for null bytes
|
||||
pass
|
||||
|
||||
def test_concurrent_access_safety(self):
|
||||
"""Test that XDG functions are thread-safe."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
self.clear_env_vars()
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def worker(thread_id):
|
||||
try:
|
||||
# Each thread sets a different XDG path
|
||||
os.environ["XDG_CONFIG_HOME"] = f"/tmp/thread_{thread_id}"
|
||||
|
||||
# Small delay to increase chance of race conditions
|
||||
time.sleep(0.01)
|
||||
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
results.append((thread_id, config_dir))
|
||||
|
||||
except Exception as e:
|
||||
errors.append((thread_id, e))
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for i in range(20):
|
||||
t = threading.Thread(target=worker, args=(i,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Check for errors
|
||||
if errors:
|
||||
self.fail(f"Thread errors: {errors}")
|
||||
|
||||
# Check that we got results from all threads
|
||||
self.assertEqual(len(results), 20)
|
||||
|
||||
def test_filesystem_limits(self):
|
||||
"""Test XDG functions approaching filesystem limits."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test with very long filename (close to 255 char limit)
|
||||
long_filename = "a" * 240
|
||||
long_path = f"/tmp/{long_filename}"
|
||||
|
||||
os.environ["XDG_CONFIG_HOME"] = long_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(long_path))
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test XDG functions with case sensitivity edge cases."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test case variations
|
||||
case_variations = [
|
||||
"/tmp/Config",
|
||||
"/tmp/CONFIG",
|
||||
"/tmp/config",
|
||||
"/tmp/Config/MixedCase",
|
||||
]
|
||||
|
||||
for case_path in case_variations:
|
||||
with self.subTest(path=case_path):
|
||||
os.environ["XDG_CONFIG_HOME"] = case_path
|
||||
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(case_path))
|
||||
|
||||
def test_ensure_directory_exists_edge_cases(self):
|
||||
"""Test ensure_directory_exists with edge cases."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# Test with file that exists but is not a directory
|
||||
file_path = temp_path / "file_not_dir"
|
||||
file_path.touch()
|
||||
|
||||
with self.assertRaises(FileExistsError):
|
||||
ensure_directory_exists(file_path)
|
||||
|
||||
# Test with permission denied
|
||||
if platform.system() != "Windows":
|
||||
readonly_parent = temp_path / "readonly_parent"
|
||||
readonly_parent.mkdir()
|
||||
readonly_parent.chmod(0o444)
|
||||
|
||||
try:
|
||||
nested_path = readonly_parent / "nested"
|
||||
|
||||
with self.assertRaises(PermissionError):
|
||||
ensure_directory_exists(nested_path)
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
readonly_parent.chmod(0o755)
|
||||
|
||||
def test_migrate_legacy_directory_edge_cases(self):
|
||||
"""Test migrate_legacy_directory with edge cases."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
home_dir = Path(temp_dir)
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = home_dir
|
||||
|
||||
# Test with legacy directory but no write permissions
|
||||
legacy_dir = home_dir / ".llama"
|
||||
legacy_dir.mkdir()
|
||||
(legacy_dir / "test_file").touch()
|
||||
|
||||
# Make home directory read-only
|
||||
home_dir.chmod(0o444)
|
||||
|
||||
try:
|
||||
# Should handle permission errors gracefully
|
||||
with patch("builtins.print") as mock_print:
|
||||
migrate_legacy_directory()
|
||||
|
||||
# Should print some information
|
||||
self.assertTrue(mock_print.called)
|
||||
|
||||
finally:
|
||||
# Restore permissions for cleanup
|
||||
home_dir.chmod(0o755)
|
||||
legacy_dir.chmod(0o755)
|
||||
|
||||
def test_path_traversal_attempts(self):
|
||||
"""Test XDG functions with path traversal attempts."""
|
||||
self.clear_env_vars()
|
||||
|
||||
traversal_attempts = [
|
||||
"/tmp/config/../../../etc/passwd",
|
||||
"/tmp/config/../../root/.ssh",
|
||||
"/tmp/config/../../../../../etc/shadow",
|
||||
"/tmp/config/./../../root",
|
||||
]
|
||||
|
||||
for traversal_attempt in traversal_attempts:
|
||||
with self.subTest(attempt=traversal_attempt):
|
||||
os.environ["XDG_CONFIG_HOME"] = traversal_attempt
|
||||
|
||||
# Should handle path traversal attempts by treating as literal paths
|
||||
result = get_xdg_config_home()
|
||||
self.assertEqual(result, Path(traversal_attempt))
|
||||
|
||||
def test_environment_variable_precedence_edge_cases(self):
|
||||
"""Test environment variable precedence with edge cases."""
|
||||
self.clear_env_vars()
|
||||
|
||||
# Test with both old and new env vars set
|
||||
os.environ["LLAMA_STACK_CONFIG_DIR"] = "/legacy/path"
|
||||
os.environ["XDG_CONFIG_HOME"] = "/xdg/path"
|
||||
|
||||
# Create fake legacy directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
fake_home = Path(temp_dir)
|
||||
fake_legacy = fake_home / ".llama"
|
||||
fake_legacy.mkdir()
|
||||
(fake_legacy / "test_file").touch()
|
||||
|
||||
with patch("llama_stack.distribution.utils.xdg_utils.Path.home") as mock_home:
|
||||
mock_home.return_value = fake_home
|
||||
|
||||
# LLAMA_STACK_CONFIG_DIR should take precedence
|
||||
config_dir = get_llama_stack_config_dir()
|
||||
self.assertEqual(config_dir, Path("/legacy/path"))
|
||||
|
||||
def test_malformed_environment_variables(self):
|
||||
"""Test XDG functions with malformed environment variables."""
|
||||
self.clear_env_vars()
|
||||
|
||||
malformed_values = [
|
||||
"not_an_absolute_path",
|
||||
"~/tilde_not_expanded",
|
||||
"$HOME/variable_not_expanded",
|
||||
"relative/path/config",
|
||||
"./relative/path",
|
||||
"../parent/path",
|
||||
]
|
||||
|
||||
for malformed_value in malformed_values:
|
||||
with self.subTest(value=malformed_value):
|
||||
os.environ["XDG_CONFIG_HOME"] = malformed_value
|
||||
|
||||
# Should handle malformed values gracefully
|
||||
result = get_xdg_config_home()
|
||||
self.assertIsInstance(result, Path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue