feat: Add MLflow Prompt Registry provider (squashed commit)

Add a new remote provider that integrates MLflow's Prompt Registry with
Llama Stack's prompts API, enabling centralized prompt management and
versioning using MLflow as the backend.

Features:
- Full implementation of Llama Stack Prompts protocol
- Support for prompt versioning and default version management
- Automatic variable extraction from Jinja2-style templates
- MLflow tag-based metadata for efficient prompt filtering
- Flexible authentication (config, environment variables, per-request)
- Bidirectional ID mapping (pmpt_<hex> ↔ llama_prompt_<hex>)
- Comprehensive error handling and validation

Implementation:
- Remote provider: src/llama_stack/providers/remote/prompts/mlflow/
- Inline reference provider: src/llama_stack/providers/inline/prompts/reference/
- MLflow 3.4+ required for Prompt Registry API support
- Deterministic ID mapping ensures consistency across conversions

Testing:
- 15 comprehensive unit tests (config validation, ID mapping)
- 18 end-to-end integration tests (full CRUD workflows)
- GitHub Actions workflow for automated CI testing with MLflow server
- Integration test fixtures with automatic server setup

Documentation:
- Complete provider configuration reference
- Setup and usage examples with code samples
- Authentication options and security best practices

Signed-off-by: William Caban <william.caban@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
William Caban 2025-11-26 09:21:44 -05:00
parent aac494c5ba
commit 0e0d311dea
24 changed files with 3594 additions and 225 deletions

View file

@ -0,0 +1,125 @@
name: MLflow Prompts Integration Tests
run-name: Run the integration test suite with MLflow Prompt Registry provider
on:
push:
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/providers/remote/prompts/mlflow/**'
- 'tests/integration/providers/remote/prompts/mlflow/**'
- 'tests/unit/providers/remote/prompts/mlflow/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/integration-mlflow-tests.yml' # This workflow
schedule:
- cron: '0 0 * * *' # Daily at 12 AM UTC
concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
test-mlflow:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
with:
python-version: ${{ matrix.python-version }}
- name: Setup MLflow Server
run: |
docker run --rm -d --pull always \
--name mlflow \
-p 5555:5555 \
ghcr.io/mlflow/mlflow:latest \
mlflow server \
--host 0.0.0.0 \
--port 5555 \
--backend-store-uri sqlite:///mlflow.db \
--default-artifact-root ./mlruns
- name: Wait for MLflow to be ready
run: |
echo "Waiting for MLflow to be ready..."
for i in {1..60}; do
if curl -s http://localhost:5555/health | grep -q '"status": "OK"'; then
echo "MLflow is ready!"
exit 0
fi
echo "Not ready yet... ($i/60)"
sleep 2
done
echo "MLflow failed to start"
docker logs mlflow
exit 1
- name: Verify MLflow API
run: |
echo "Testing MLflow API..."
curl -X GET http://localhost:5555/api/2.0/mlflow/experiments/list
echo ""
echo "MLflow API is responding!"
- name: Build Llama Stack
run: |
uv run --no-sync llama stack list-deps ci-tests | xargs -L1 uv pip install
- name: Install MLflow Python client
run: |
uv pip install 'mlflow>=3.4.0'
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Run MLflow Integration Tests
env:
MLFLOW_TRACKING_URI: http://localhost:5555
run: |
uv run --no-sync \
pytest -sv \
tests/integration/providers/remote/prompts/mlflow/
- name: Check Storage and Memory Available After Tests
if: ${{ always() }}
run: |
free -h
df -h
- name: Write MLflow logs to file
if: ${{ always() }}
run: |
docker logs mlflow > mlflow.log 2>&1 || true
- name: Upload all logs to artifacts
if: ${{ always() }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: mlflow-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
path: |
*.log
retention-days: 1
- name: Stop MLflow container
if: ${{ always() }}
run: |
docker stop mlflow || true

View file

@ -0,0 +1,92 @@
---
sidebar_label: Prompts
title: Prompts
---
# Prompts
## Overview
This section contains documentation for all available providers for the **prompts** API.
The Prompts API enables centralized management of prompt templates with versioning, variable handling, and team collaboration capabilities.
## Available Providers
### Inline Providers
Inline providers run in the same process as the Llama Stack server and require no external dependencies:
- **[inline::reference](inline_reference.mdx)** - Reference implementation using KVStore backend (SQLite, PostgreSQL, etc.)
- Zero external dependencies
- Supports local SQLite or PostgreSQL storage
- Full CRUD operations including deletion
- Ideal for local development and single-server deployments
### Remote Providers
Remote providers connect to external services for centralized prompt management:
- **[remote::mlflow](remote_mlflow.mdx)** - MLflow Prompt Registry integration (requires MLflow 3.4+)
- Centralized prompt management across teams
- Built-in versioning and audit trail
- Supports authentication (per-request, config, or environment variables)
- Integrates with Databricks and enterprise MLflow deployments
- Ideal for team collaboration and production environments
## Choosing a Provider
### Use `inline::reference` when:
- Developing locally or deploying to a single server
- You want zero external dependencies
- SQLite or PostgreSQL storage is sufficient
- You need full CRUD operations (including deletion)
- You prefer simple configuration
### Use `remote::mlflow` when:
- Working in a team environment with multiple users
- You need centralized prompt management
- Integration with existing MLflow infrastructure
- You need authentication and multi-tenant support
- Advanced versioning and audit trail capabilities are required
## Quick Start Examples
### Using inline::reference
```yaml
prompts:
- provider_id: local-prompts
provider_type: inline::reference
config:
run_config:
storage:
stores:
prompts:
type: sqlite
db_path: ./prompts.db
```
### Using remote::mlflow
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://localhost:5555
experiment_name: llama-stack-prompts
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
```
## Common Features
All prompt providers support:
- Create and store prompts with version control
- Retrieve prompts by ID and version
- Update prompts (creates new versions)
- List all prompts or versions of a specific prompt
- Set default version for a prompt
- Automatic variable extraction from `{{ variable }}` templates
For detailed documentation on each provider, see the individual provider pages linked above.

View file

@ -0,0 +1,496 @@
---
description: |
Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.)
for centralized prompt management with versioning support. This is the default provider for
prompts that works without external dependencies.
## Features
The Reference Prompts Provider supports:
- Create and store prompts with automatic versioning
- Retrieve prompts by ID and version
- Update prompts (creates new immutable versions)
- Delete prompts and their versions
- List all prompts or all versions of a specific prompt
- Set default version for a prompt
- Automatic variable extraction from templates
- Storage in SQLite, PostgreSQL, or other KVStore backends
## Key Capabilities
- **Zero Dependencies**: No external services required, runs in-process
- **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends
- **Version Control**: Immutable versioning ensures prompt history is preserved
- **Default Version Management**: Easily switch between prompt versions
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
- **Full CRUD Support**: Unlike remote providers, supports deletion of prompts
## Usage
To use Reference Prompts Provider in your Llama Stack project:
1. Configure your Llama Stack project with the inline::reference provider
2. Optionally configure storage backend (defaults to SQLite)
3. Start creating and managing prompts
## Quick Start
### 1. Configure Llama Stack
**Basic configuration with SQLite** (default):
```yaml
prompts:
- provider_id: reference-prompts
provider_type: inline::reference
config:
run_config:
storage:
stores:
prompts:
type: sqlite
db_path: ./prompts.db
```
**With PostgreSQL**:
```yaml
prompts:
- provider_id: postgres-prompts
provider_type: inline::reference
config:
run_config:
storage:
stores:
prompts:
type: postgres
url: postgresql://user:pass@localhost/llama_stack
```
### 2. Use the Prompts API
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:5000")
# Create a prompt
prompt = client.prompts.create(
prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}",
variables=["num_sentences", "text"]
)
print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})")
# Retrieve prompt
retrieved = client.prompts.get(prompt_id=prompt.prompt_id)
print(f"Retrieved: {retrieved.prompt}")
# Update prompt (creates version 2)
updated = client.prompts.update(
prompt_id=prompt.prompt_id,
prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}",
version=1,
set_as_default=True
)
print(f"Updated to version: {updated.version}")
# List all prompts
prompts = client.prompts.list()
print(f"Found {len(prompts.data)} prompts")
# Delete prompt
client.prompts.delete(prompt_id=prompt.prompt_id)
```
sidebar_label: Inline - Reference
title: inline::reference
---
# inline::reference
## Description
Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.)
for centralized prompt management with versioning support. This is the default provider for
prompts that works without external dependencies.
## Features
The Reference Prompts Provider supports:
- Create and store prompts with automatic versioning
- Retrieve prompts by ID and version
- Update prompts (creates new immutable versions)
- Delete prompts and their versions
- List all prompts or all versions of a specific prompt
- Set default version for a prompt
- Automatic variable extraction from templates
- Storage in SQLite, PostgreSQL, or other KVStore backends
## Key Capabilities
- **Zero Dependencies**: No external services required, runs in-process
- **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends
- **Version Control**: Immutable versioning ensures prompt history is preserved
- **Default Version Management**: Easily switch between prompt versions
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
- **Full CRUD Support**: Unlike remote providers, supports deletion of prompts
## Configuration Examples
### SQLite (Local Development)
For local development with filesystem storage:
```yaml
prompts:
- provider_id: local-prompts
provider_type: inline::reference
config:
run_config:
storage:
stores:
prompts:
type: sqlite
db_path: ./prompts.db
```
### PostgreSQL (Production)
For production with PostgreSQL:
```yaml
prompts:
- provider_id: prod-prompts
provider_type: inline::reference
config:
run_config:
storage:
stores:
prompts:
type: postgres
url: ${env.DATABASE_URL}
```
### With Explicit Backend Configuration
```yaml
prompts:
- provider_id: reference-prompts
provider_type: inline::reference
config:
run_config:
storage:
backends:
kv_default:
type: sqlite
db_path: ./data/prompts.db
stores:
prompts:
backend: kv_default
namespace: prompts
```
## API Reference
### Create Prompt
Creates a new prompt (version 1):
```python
prompt = client.prompts.create(
prompt="You are a {{ role }} assistant. {{ instruction }}",
variables=["role", "instruction"] # Optional - auto-extracted if omitted
)
```
**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders.
### Retrieve Prompt
Get a prompt by ID (retrieves default version):
```python
prompt = client.prompts.get(prompt_id="pmpt_abc123...")
```
Get a specific version:
```python
prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2)
```
### Update Prompt
Creates a new version of an existing prompt:
```python
updated = client.prompts.update(
prompt_id="pmpt_abc123...",
prompt="Updated template with {{ variable }}",
version=1, # Must be the latest version
set_as_default=True # Make this the new default
)
```
**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2).
### Delete Prompt
Delete a prompt and all its versions:
```python
client.prompts.delete(prompt_id="pmpt_abc123...")
```
**Note**: This operation is permanent and deletes all versions of the prompt.
### List Prompts
List all prompts (returns default versions only):
```python
response = client.prompts.list()
for prompt in response.data:
print(f"{prompt.prompt_id}: v{prompt.version} (default)")
```
### List Prompt Versions
List all versions of a specific prompt:
```python
response = client.prompts.list_versions(prompt_id="pmpt_abc123...")
for prompt in response.data:
default = " (default)" if prompt.is_default else ""
print(f"Version {prompt.version}{default}")
```
### Set Default Version
Change which version is the default:
```python
client.prompts.set_default_version(
prompt_id="pmpt_abc123...",
version=2
)
```
## Version Management
The Reference Prompts Provider implements immutable versioning:
1. **Create**: Creates version 1
2. **Update**: Creates a new version (2, 3, 4, ...)
3. **Default**: One version is marked as default
4. **History**: All versions are preserved and retrievable
5. **Delete**: Can delete all versions at once
```
pmpt_abc123
├── Version 1 (Original)
├── Version 2 (Updated)
└── Version 3 (Latest, Default) <- Current default version
```
## Storage Backends
The reference provider uses Llama Stack's KVStore abstraction, which supports multiple backends:
### SQLite (Default)
Best for:
- Local development
- Single-server deployments
- Embedded applications
- Testing
Limitations:
- Not suitable for high-concurrency scenarios
- No built-in replication
### PostgreSQL
Best for:
- Production deployments
- Multi-server setups
- High availability requirements
- Team collaboration
Advantages:
- Supports concurrent access
- Built-in replication and backups
- Scalable and robust
## Best Practices
### 1. Choose Appropriate Storage
**Development**:
```yaml
# Use SQLite for local development
storage:
stores:
prompts:
type: sqlite
db_path: ./dev-prompts.db
```
**Production**:
```yaml
# Use PostgreSQL for production
storage:
stores:
prompts:
type: postgres
url: ${env.DATABASE_URL}
```
### 2. Backup Your Data
For SQLite:
```bash
# Backup SQLite database
cp prompts.db prompts.db.backup
```
For PostgreSQL:
```bash
# Backup PostgreSQL database
pg_dump llama_stack > backup.sql
```
### 3. Version Management
- Always retrieve latest version before updating
- Use `set_as_default=True` when updating to make new version active
- Keep version history for audit trail
- Use deletion sparingly (consider archiving instead)
### 4. Auto-Extract Variables
Let the provider auto-extract variables to avoid validation errors:
```python
# Recommended
prompt = client.prompts.create(
prompt="Summarize {{ text }} in {{ format }}"
)
```
### 5. Use Meaningful Templates
Include context in your templates:
```python
# Good
prompt = """You are a {{ role }} assistant specialized in {{ domain }}.
Task: {{ task }}
Output format: {{ format }}"""
# Less clear
prompt = "Do {{ task }} as {{ role }}"
```
## Troubleshooting
### Database Connection Errors
**Error**: Failed to connect to database
**Solutions**:
1. Verify database URL is correct
2. Ensure database server is running (for PostgreSQL)
3. Check file permissions (for SQLite)
4. Verify network connectivity (for remote databases)
### Version Mismatch Error
**Error**: `Version X is not the latest version. Use latest version Y to update.`
**Cause**: Attempting to update an outdated version
**Solution**: Always use the latest version number when updating:
```python
# Get latest version
versions = client.prompts.list_versions(prompt_id)
latest_version = max(v.version for v in versions.data)
# Use latest version for update
client.prompts.update(prompt_id=prompt_id, version=latest_version, ...)
```
### Variable Validation Error
**Error**: `Template contains undeclared variables: ['var2']`
**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it
**Solution**: Either add missing variable or let the provider auto-extract:
```python
# Option 1: Add missing variable
client.prompts.create(
prompt="Template with {{ var1 }} and {{ var2 }}",
variables=["var1", "var2"]
)
# Option 2: Let provider auto-extract (recommended)
client.prompts.create(
prompt="Template with {{ var1 }} and {{ var2 }}"
)
```
### Prompt Not Found
**Error**: `Prompt pmpt_abc123... not found`
**Possible causes**:
1. Prompt ID is incorrect
2. Prompt was deleted
3. Wrong database or storage backend
**Solution**: Verify prompt exists using `list()` method
## Migration Guide
### Migrating from Core Implementation
If you're upgrading from an older Llama Stack version where prompts were in `core/prompts`:
**Old code** (still works):
```python
from llama_stack.core.prompts import PromptServiceConfig, PromptServiceImpl
```
**New code** (recommended):
```python
from llama_stack.providers.inline.prompts.reference import ReferencePromptsConfig, PromptServiceImpl
```
**Note**: Backward compatibility is maintained. Old imports still work.
### Data Migration
No data migration needed when upgrading:
- Same KVStore backend is used
- Existing prompts remain accessible
- Configuration structure is compatible
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `run_config` | `StackRunConfig` | Yes | | Stack run configuration containing storage configuration for KVStore |
## Sample Configuration
```yaml
run_config:
storage:
backends:
kv_default:
type: sqlite
db_path: ./prompts.db
stores:
prompts:
backend: kv_default
namespace: prompts
```

View file

@ -0,0 +1,751 @@
---
description: |
[MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning
using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage
prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management.
See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry.
sidebar_label: Remote - MLflow
title: remote::mlflow
---
# remote::mlflow
## Description
[MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning
using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage
prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management.
## Features
MLflow Prompts Provider supports:
- Create and store prompts with automatic versioning
- Retrieve prompts by ID and version
- Update prompts (creates new immutable versions)
- List all prompts or all versions of a specific prompt
- Set default version for a prompt
- Automatic variable extraction from templates
- Metadata storage and retrieval
- Centralized prompt management across teams
## Key Capabilities
- **Version Control**: Immutable versioning ensures prompt history is preserved
- **Default Version Management**: Easily switch between prompt versions
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
- **Metadata Tags**: Stores Llama Stack metadata for seamless integration
- **Team Collaboration**: Centralized MLflow server enables multi-user access
## Usage
To use MLflow Prompts Provider in your Llama Stack project:
1. Install MLflow 3.4 or later
2. Start an MLflow server (local or remote)
3. Configure your Llama Stack project to use the MLflow provider
4. Start creating and managing prompts
## Installation
Install MLflow using pip or uv:
```bash
pip install 'mlflow>=3.4.0'
# or
uv pip install 'mlflow>=3.4.0'
```
## Quick Start
### 1. Start MLflow Server
**Local server** (for development):
```bash
mlflow server --host 127.0.0.1 --port 5555
```
**Remote server** (for production):
```bash
mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri postgresql://user:pass@host/db
```
### 2. Configure Llama Stack
Add to your Llama Stack configuration:
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://localhost:5555
experiment_name: llama-stack-prompts
```
### 3. Use the Prompts API
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:5000")
# Create a prompt
prompt = client.prompts.create(
prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}",
variables=["num_sentences", "text"]
)
print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})")
# Retrieve prompt
retrieved = client.prompts.get(prompt_id=prompt.prompt_id)
print(f"Retrieved: {retrieved.prompt}")
# Update prompt (creates version 2)
updated = client.prompts.update(
prompt_id=prompt.prompt_id,
prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}",
version=1,
set_as_default=True
)
print(f"Updated to version: {updated.version}")
# List all prompts
prompts = client.prompts.list()
print(f"Found {len(prompts.data)} prompts")
```
## Configuration Examples
### Local Development
For local development with filesystem storage:
```yaml
prompts:
- provider_id: mlflow-local
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://localhost:5555
experiment_name: dev-prompts
timeout_seconds: 30
```
### Remote MLflow Server
For production with a remote MLflow server:
```yaml
prompts:
- provider_id: mlflow-production
provider_type: remote::mlflow
config:
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI}
experiment_name: production-prompts
timeout_seconds: 60
```
### Advanced Configuration
With custom settings:
```yaml
prompts:
- provider_id: mlflow-custom
provider_type: remote::mlflow
config:
mlflow_tracking_uri: https://mlflow.example.com
experiment_name: team-prompts
timeout_seconds: 45
```
## Authentication
The MLflow provider supports three authentication methods with the following precedence (highest to lowest):
1. **Per-Request Provider Data** (via headers)
2. **Configuration Auth Credential** (in config file)
3. **Environment Variables** (MLflow defaults)
### Method 1: Per-Request Provider Data (Recommended for Multi-Tenant)
For multi-tenant deployments where each user has their own credentials:
**Configuration**:
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://mlflow.company.com
experiment_name: production-prompts
# No auth_credential - use per-request tokens
```
**Client Usage**:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:5000")
# User 1 with their own token
prompts_user1 = client.prompts.list(
extra_headers={
"x-llamastack-provider-data": '{"mlflow_api_token": "user1-token"}'
}
)
# User 2 with their own token
prompts_user2 = client.prompts.list(
extra_headers={
"x-llamastack-provider-data": '{"mlflow_api_token": "user2-token"}'
}
)
```
**Benefits**:
- Per-user authentication and authorization
- No shared credentials
- Ideal for SaaS deployments
- Supports user-specific MLflow experiments
### Method 2: Configuration Auth Credential (Server-Level)
For server-level authentication where all requests use the same credentials:
**Using Environment Variable** (recommended):
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://mlflow.company.com
experiment_name: production-prompts
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
```
**Using Direct Value** (not recommended for production):
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://mlflow.company.com
experiment_name: production-prompts
auth_credential: "mlflow-server-token"
```
**Client Usage**:
```python
# No extra headers needed - server handles authentication
client = LlamaStackClient(base_url="http://localhost:5000")
prompts = client.prompts.list()
```
**Benefits**:
- Simple configuration
- Single point of credential management
- Good for single-tenant deployments
### Method 3: Environment Variables (MLflow Default)
MLflow reads standard environment variables automatically:
**Set before starting Llama Stack**:
```bash
export MLFLOW_TRACKING_TOKEN="your-token"
export MLFLOW_TRACKING_USERNAME="user" # Optional: Basic auth
export MLFLOW_TRACKING_PASSWORD="pass" # Optional: Basic auth
llama stack run my-config.yaml
```
**Configuration** (no auth_credential needed):
```yaml
prompts:
- provider_id: mlflow-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://mlflow.company.com
experiment_name: production-prompts
```
**Benefits**:
- Standard MLflow behavior
- No configuration changes needed
- Good for containerized deployments
### Databricks Authentication
For Databricks-managed MLflow:
**Configuration**:
```yaml
prompts:
- provider_id: databricks-prompts
provider_type: remote::mlflow
config:
mlflow_tracking_uri: databricks
# Or with workspace URL:
# mlflow_tracking_uri: databricks://profile-name
experiment_name: /Shared/llama-stack-prompts
auth_credential: ${env.DATABRICKS_TOKEN}
```
**Environment Setup**:
```bash
export DATABRICKS_TOKEN="dapi..."
export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
```
**Client Usage**:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:5000")
# Create prompt in Databricks MLflow
prompt = client.prompts.create(
prompt="Analyze {{ topic }} with focus on {{ aspect }}",
variables=["topic", "aspect"]
)
# View in Databricks UI:
# https://workspace.cloud.databricks.com/#mlflow/experiments/<experiment-id>
```
### Enterprise MLflow with Authentication
Example for enterprise MLflow server with API key authentication:
**Configuration**:
```yaml
prompts:
- provider_id: enterprise-mlflow
provider_type: remote::mlflow
config:
mlflow_tracking_uri: https://mlflow.enterprise.com
experiment_name: production-prompts
auth_credential: ${env.MLFLOW_API_KEY}
timeout_seconds: 60
```
**Client Usage**:
```python
from llama_stack_client import LlamaStackClient
# Option A: Use server's configured credential
client = LlamaStackClient(base_url="http://localhost:5000")
prompt = client.prompts.create(
prompt="Classify sentiment: {{ text }}",
variables=["text"]
)
# Option B: Override with per-request credential
prompt = client.prompts.create(
prompt="Classify sentiment: {{ text }}",
variables=["text"],
extra_headers={
"x-llamastack-provider-data": '{"mlflow_api_token": "user-specific-key"}'
}
)
```
### Authentication Precedence
When multiple authentication methods are configured, the provider uses this precedence:
1. **Per-request provider data** (from `x-llamastack-provider-data` header)
- Highest priority
- Overrides all other methods
- Used for multi-tenant scenarios
2. **Configuration auth_credential** (from config file)
- Medium priority
- Fallback if no provider data header
- Good for server-level auth
3. **Environment variables** (MLflow standard)
- Lowest priority
- Used if no other credentials provided
- Standard MLflow behavior
**Example showing precedence**:
```yaml
# Config file
prompts:
- provider_id: mlflow
provider_type: remote::mlflow
config:
mlflow_tracking_uri: http://mlflow.company.com
auth_credential: ${env.MLFLOW_TRACKING_TOKEN} # Fallback
```
```bash
# Environment variable
export MLFLOW_TRACKING_TOKEN="server-token" # Lowest priority
```
```python
# Client code
client.prompts.create(
prompt="Test",
extra_headers={
# This takes precedence over config and env vars
"x-llamastack-provider-data": '{"mlflow_api_token": "user-token"}'
}
)
```
### Security Best Practices
1. **Never hardcode tokens** in configuration files:
```yaml
# Bad - hardcoded credential
auth_credential: "my-secret-token"
# Good - use environment variable
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
```
2. **Use per-request credentials** for multi-tenant deployments:
```python
# Good - each user provides their own token
headers = {
"x-llamastack-provider-data": f'{{"mlflow_api_token": "{user_token}"}}'
}
client.prompts.list(extra_headers=headers)
```
3. **Rotate credentials regularly** in production environments
4. **Use HTTPS** for MLflow tracking URI in production:
```yaml
mlflow_tracking_uri: https://mlflow.company.com # Good
# Not: http://mlflow.company.com # Bad for production
```
5. **Store secrets in secure vaults** (AWS Secrets Manager, HashiCorp Vault, etc.)
## API Reference
### Create Prompt
Creates a new prompt (version 1) or registers a prompt in MLflow:
```python
prompt = client.prompts.create(
prompt="You are a {{ role }} assistant. {{ instruction }}",
variables=["role", "instruction"] # Optional - auto-extracted if omitted
)
```
**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders.
### Retrieve Prompt
Get a prompt by ID (retrieves default version):
```python
prompt = client.prompts.get(prompt_id="pmpt_abc123...")
```
Get a specific version:
```python
prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2)
```
### Update Prompt
Creates a new version of an existing prompt:
```python
updated = client.prompts.update(
prompt_id="pmpt_abc123...",
prompt="Updated template with {{ variable }}",
version=1, # Must be the latest version
set_as_default=True # Make this the new default
)
```
**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2).
### List Prompts
List all prompts (returns default versions only):
```python
response = client.prompts.list()
for prompt in response.data:
print(f"{prompt.prompt_id}: v{prompt.version} (default)")
```
### List Prompt Versions
List all versions of a specific prompt:
```python
response = client.prompts.list_versions(prompt_id="pmpt_abc123...")
for prompt in response.data:
default = " (default)" if prompt.is_default else ""
print(f"Version {prompt.version}{default}")
```
### Set Default Version
Change which version is the default:
```python
client.prompts.set_default_version(
prompt_id="pmpt_abc123...",
version=2
)
```
## ID Mapping
The MLflow provider uses deterministic bidirectional ID mapping:
- **Llama Stack format**: `pmpt_<48-hex-chars>`
- **MLflow format**: `llama_prompt_<48-hex-chars>`
Example:
- Llama Stack ID: `pmpt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c`
- MLflow name: `llama_prompt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c`
This ensures prompts created through Llama Stack are easily identifiable in MLflow.
## Version Management
MLflow Prompts Provider implements immutable versioning:
1. **Create**: Creates version 1
2. **Update**: Creates a new version (2, 3, 4, ...)
3. **Default**: The "default" alias points to the current default version
4. **History**: All versions are preserved and retrievable
```
pmpt_abc123
├── Version 1 (Original)
├── Version 2 (Updated)
└── Version 3 (Latest, Default) ← Default alias points here
```
## Troubleshooting
### MLflow Server Not Available
**Error**: `Failed to connect to MLflow server`
**Solutions**:
1. Verify MLflow server is running: `curl http://localhost:5555/health`
2. Check `mlflow_tracking_uri` in configuration
3. Ensure network connectivity to remote server
4. Check firewall settings
### Version Mismatch Error
**Error**: `Version X is not the latest version. Use latest version Y to update.`
**Cause**: Attempting to update an outdated version
**Solution**: Always use the latest version number when updating:
```python
# Get latest version
versions = client.prompts.list_versions(prompt_id)
latest_version = max(v.version for v in versions.data)
# Use latest version for update
client.prompts.update(prompt_id=prompt_id, version=latest_version, ...)
```
### Variable Validation Error
**Error**: `Template contains undeclared variables: ['var2']`
**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it
**Solution**: Either add missing variable or let the provider auto-extract:
```python
# Option 1: Add missing variable
client.prompts.create(
prompt="Template with {{ var1 }} and {{ var2 }}",
variables=["var1", "var2"]
)
# Option 2: Let provider auto-extract (recommended)
client.prompts.create(
prompt="Template with {{ var1 }} and {{ var2 }}"
)
```
### Timeout Errors
**Error**: Connection timeout when communicating with MLflow
**Solutions**:
1. Increase `timeout_seconds` in configuration:
```yaml
config:
timeout_seconds: 60 # Default: 30
```
2. Check network latency to MLflow server
3. Verify MLflow server is responsive
### Prompt Not Found
**Error**: `Prompt pmpt_abc123... not found`
**Possible causes**:
1. Prompt ID is incorrect
2. Prompt was created in a different MLflow server/experiment
3. Experiment name mismatch in configuration
**Solution**: Verify prompt exists in MLflow UI at `http://localhost:5555`
## Limitations
### No Deletion Support
**MLflow does not support deleting prompts or versions**. The `delete_prompt()` method raises `NotImplementedError`.
**Workaround**: Mark prompts as deprecated using naming conventions or set a different version as default.
### Experiment Required
All prompts are stored within an MLflow experiment. The experiment is created automatically if it doesn't exist.
### ID Format Constraints
- Prompt IDs must follow the format: `pmpt_<48-hex-chars>`
- MLflow names use the prefix: `llama_prompt_`
- Manual creation in MLflow with different names won't be recognized
### Version Numbering
- Versions are sequential integers (1, 2, 3, ...)
- You cannot skip version numbers
- You cannot manually set version numbers
## Best Practices
### 1. Use Environment Variables
Store MLflow URIs in environment variables:
```yaml
config:
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5555}
```
### 2. Auto-Extract Variables
Let the provider auto-extract variables to avoid validation errors:
```python
# Recommended
prompt = client.prompts.create(
prompt="Summarize {{ text }} in {{ format }}"
)
```
### 3. Organize by Experiment
Use different experiments for different environments:
- `dev-prompts` for development
- `staging-prompts` for staging
- `production-prompts` for production
### 4. Version Management
- Always retrieve latest version before updating
- Use `set_as_default=True` when updating to make new version active
- Keep version history for audit trail
### 5. Use Meaningful Templates
Include context in your templates:
```python
# Good
prompt = """You are a {{ role }} assistant specialized in {{ domain }}.
Task: {{ task }}
Output format: {{ format }}"""
# Less clear
prompt = "Do {{ task }} as {{ role }}"
```
### 6. Monitor MLflow Server
- Use MLflow UI to visualize prompts: `http://your-server:5555`
- Monitor experiment metrics and prompt versions
- Set up alerts for MLflow server health
## Production Deployment
### Database Backend
For production, use a database backend instead of filesystem:
```bash
mlflow server \
--host 0.0.0.0 \
--port 5000 \
--backend-store-uri postgresql://user:pass@host:5432/mlflow \
--default-artifact-root s3://my-bucket/mlflow-artifacts
```
### High Availability
- Deploy multiple MLflow server instances behind a load balancer
- Use managed database (RDS, Cloud SQL, etc.)
- Store artifacts in object storage (S3, GCS, Azure Blob)
### Security
- Enable authentication on MLflow server
- Use HTTPS for MLflow tracking URI
- Restrict network access with firewall rules
- Use IAM roles for cloud deployments
### Monitoring
Set up monitoring for:
- MLflow server availability
- Database connection pool
- API response times
- Prompt creation/retrieval rates
## Documentation
See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `mlflow_tracking_uri` | `str` | No | http://localhost:5000 | MLflow tracking server URI |
| `mlflow_registry_uri` | `str \| None` | No | None | MLflow model registry URI (defaults to tracking URI if not set) |
| `experiment_name` | `str` | No | llama-stack-prompts | MLflow experiment name for storing prompts |
| `auth_credential` | `SecretStr \| None` | No | None | MLflow API token for authentication. Can be overridden via provider data header. |
| `timeout_seconds` | `int` | No | 30 | Timeout for MLflow API calls (1-300 seconds) |
## Sample Configuration
**Without authentication** (local development):
```yaml
mlflow_tracking_uri: http://localhost:5555
experiment_name: llama-stack-prompts
timeout_seconds: 30
```
**With authentication** (production):
```yaml
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5000}
experiment_name: llama-stack-prompts
auth_credential: ${env.MLFLOW_TRACKING_TOKEN:=}
timeout_seconds: 30
```

View file

@ -4,232 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json """Core prompts service delegating to inline::reference provider.
from typing import Any
from pydantic import BaseModel This module provides backward compatibility by delegating to the
inline::reference provider implementation.
"""
from llama_stack.core.datatypes import StackRunConfig from llama_stack.providers.inline.prompts.reference import (
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl PromptServiceImpl,
from llama_stack_api import ListPromptsResponse, Prompt, Prompts ReferencePromptsConfig,
get_adapter_impl,
)
# Re-export for backward compatibility
PromptServiceConfig = ReferencePromptsConfig
get_provider_impl = get_adapter_impl
class PromptServiceConfig(BaseModel): __all__ = [
"""Configuration for the built-in prompt service. "PromptServiceImpl",
"PromptServiceConfig",
:param run_config: Stack run configuration containing distribution info "ReferencePromptsConfig",
""" "get_provider_impl",
"get_adapter_impl",
run_config: StackRunConfig ]
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
"""Get the prompt service implementation."""
impl = PromptServiceImpl(config, deps)
await impl.initialize()
return impl
class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
# Use prompts store reference from run config
prompts_ref = self.config.run_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
if resolved_version is None:
raise ValueError(f"Prompt {prompt_id}:default not found")
return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
return f"prompts:v1:{prompt_id}:{version}"
def _get_list_key_prefix(self) -> str:
"""Get the key prefix for listing prompts."""
return "prompts:v1:"
def _serialize_prompt(self, prompt: Prompt) -> str:
"""Serialize a prompt to JSON string for storage."""
return json.dumps(
{
"prompt_id": prompt.prompt_id,
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
"is_default": prompt.is_default,
}
)
def _deserialize_prompt(self, data: str) -> Prompt:
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only)."""
prefix = self._get_list_key_prefix()
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
prompts = []
for key in keys:
if key.endswith(":default"):
try:
default_version = await self.kvstore.get(key)
if default_version:
prompt_id = key.replace(prefix, "").replace(":default", "")
version_key = self._get_version_key(prompt_id, default_version)
data = await self.kvstore.get(version_key)
if data:
prompt = self._deserialize_prompt(data)
prompts.append(prompt)
except (json.JSONDecodeError, KeyError):
continue
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
if data is None:
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt."""
if variables is None:
variables = []
prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None:
variables = []
prompt_versions = await self.list_prompt_versions(prompt_id)
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
if version and latest_prompt.version != version:
raise ValueError(
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
)
current_version = latest_prompt.version if version is None else version
new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
if set_as_default:
await self.set_default_version(prompt_id, new_version)
return updated_prompt
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete a prompt and all its versions."""
await self.get_prompt(prompt_id)
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
for key in keys:
await self.kvstore.delete(key)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt."""
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
default_version = None
prompts = []
for key in keys:
data = await self.kvstore.get(key)
if key.endswith(":default"):
default_version = data
else:
if data:
prompt_obj = self._deserialize_prompt(data)
prompts.append(prompt_obj)
if not prompts:
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ReferencePromptsConfig
from .reference import PromptServiceImpl
async def get_adapter_impl(config: ReferencePromptsConfig, _deps):
impl = PromptServiceImpl(config=config, deps=_deps)
await impl.initialize()
return impl
__all__ = ["ReferencePromptsConfig", "PromptServiceImpl", "get_adapter_impl"]

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.core.datatypes import StackRunConfig
class ReferencePromptsConfig(BaseModel):
"""Configuration for the built-in reference prompt service.
This provider stores prompts in the configured KVStore (SQLite, PostgreSQL, etc.)
as specified in the run configuration.
"""
run_config: StackRunConfig = Field(
description="Stack run configuration containing storage configuration"
)

View file

@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
from .config import ReferencePromptsConfig
class PromptServiceImpl(Prompts):
"""Reference inline prompt service implementation using KVStore.
This provider stores prompts in the configured KVStore backend (SQLite, PostgreSQL, etc.)
and provides full CRUD operations with versioning support.
"""
def __init__(self, config: ReferencePromptsConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
# Use prompts store reference from run config
prompts_ref = self.config.run_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
if resolved_version is None:
raise ValueError(f"Prompt {prompt_id}:default not found")
return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
return f"prompts:v1:{prompt_id}:{version}"
def _get_list_key_prefix(self) -> str:
"""Get the key prefix for listing prompts."""
return "prompts:v1:"
def _serialize_prompt(self, prompt: Prompt) -> str:
"""Serialize a prompt to JSON string for storage."""
return json.dumps(
{
"prompt_id": prompt.prompt_id,
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
"is_default": prompt.is_default,
}
)
def _deserialize_prompt(self, data: str) -> Prompt:
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only)."""
prefix = self._get_list_key_prefix()
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
prompts = []
for key in keys:
if key.endswith(":default"):
try:
default_version = await self.kvstore.get(key)
if default_version:
prompt_id = key.replace(prefix, "").replace(":default", "")
version_key = self._get_version_key(prompt_id, default_version)
data = await self.kvstore.get(version_key)
if data:
prompt = self._deserialize_prompt(data)
prompts.append(prompt)
except (json.JSONDecodeError, KeyError):
continue
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
if data is None:
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt."""
if variables is None:
variables = []
prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None:
variables = []
prompt_versions = await self.list_prompt_versions(prompt_id)
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
if version and latest_prompt.version != version:
raise ValueError(
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
)
current_version = latest_prompt.version if version is None else version
new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
if set_as_default:
await self.set_default_version(prompt_id, new_version)
return updated_prompt
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete a prompt and all its versions."""
await self.get_prompt(prompt_id)
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
for key in keys:
await self.kvstore.delete(key)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt."""
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
default_version = None
prompts = []
for key in keys:
data = await self.kvstore.get(key)
if key.endswith(":default"):
default_version = data
else:
if data:
prompt_obj = self._deserialize_prompt(data)
prompts.append(prompt_obj)
if not prompts:
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.prompts,
provider_type="inline::reference",
pip_packages=[],
module="llama_stack.providers.inline.prompts.reference",
config_class="llama_stack.providers.inline.prompts.reference.ReferencePromptsConfig",
description="Reference implementation storing prompts in KVStore (SQLite, PostgreSQL, etc.)",
),
RemoteProviderSpec(
api=Api.prompts,
adapter_type="mlflow",
provider_type="remote::mlflow",
pip_packages=["mlflow>=3.4.0"],
module="llama_stack.providers.remote.prompts.mlflow",
config_class="llama_stack.providers.remote.prompts.mlflow.MLflowPromptsConfig",
provider_data_validator="llama_stack.providers.remote.prompts.mlflow.config.MLflowProviderDataValidator",
description="MLflow Prompt Registry provider for centralized prompt management and versioning",
),
]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MLflowPromptsConfig
from .mlflow import MLflowPromptsAdapter
__all__ = ["MLflowPromptsConfig", "MLflowPromptsAdapter", "get_adapter_impl"]
async def get_adapter_impl(config: MLflowPromptsConfig, _deps):
"""Get the MLflow prompts adapter implementation."""
impl = MLflowPromptsAdapter(config=config)
await impl.initialize()
return impl

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Configuration for MLflow Prompt Registry provider.
This module defines the configuration schema for integrating Llama Stack
with MLflow's Prompt Registry for centralized prompt management and versioning.
"""
from typing import Any
from pydantic import BaseModel, Field, SecretStr, field_validator
from llama_stack_api import json_schema_type
class MLflowProviderDataValidator(BaseModel):
"""Validator for provider data from request headers.
This allows users to override the MLflow API token per request via
the x-llamastack-provider-data header:
{"mlflow_api_token": "your-token"}
"""
mlflow_api_token: str | None = Field(
default=None,
description="MLflow API token for authentication (overrides config)",
)
@json_schema_type
class MLflowPromptsConfig(BaseModel):
"""Configuration for MLflow Prompt Registry provider.
Credentials can be provided via:
1. Per-request provider data header (preferred for security)
2. Configuration auth_credential (fallback)
3. Environment variables set by MLflow (MLFLOW_TRACKING_TOKEN, etc.)
Attributes:
mlflow_tracking_uri: MLflow tracking server URI (e.g., http://localhost:5000, databricks)
mlflow_registry_uri: MLflow registry URI (optional, defaults to tracking_uri)
experiment_name: MLflow experiment name for prompt storage
auth_credential: MLflow API token for authentication (optional, can be overridden by provider data)
timeout_seconds: Timeout for MLflow API calls in seconds (default: 30)
"""
mlflow_tracking_uri: str = Field(
default="http://localhost:5000",
description="MLflow tracking server URI (e.g., http://localhost:5000, databricks, databricks://profile)",
)
mlflow_registry_uri: str | None = Field(
default=None,
description="MLflow registry URI (defaults to tracking_uri if not specified)",
)
experiment_name: str = Field(
default="llama-stack-prompts",
description="MLflow experiment name for prompt storage and organization",
)
auth_credential: SecretStr | None = Field(
default=None,
description="MLflow API token for authentication. Can be overridden via provider data header.",
)
timeout_seconds: int = Field(
default=30,
ge=1,
le=300,
description="Timeout for MLflow API calls in seconds (1-300)",
)
@classmethod
def sample_run_config(cls, mlflow_api_token: str = "${env.MLFLOW_TRACKING_TOKEN:=}", **kwargs) -> dict[str, Any]:
"""Generate sample configuration with environment variable substitution.
Args:
mlflow_api_token: MLflow API token (defaults to MLFLOW_TRACKING_TOKEN env var)
**kwargs: Additional configuration overrides
Returns:
Sample configuration dictionary
"""
return {
"mlflow_tracking_uri": kwargs.get("mlflow_tracking_uri", "http://localhost:5000"),
"experiment_name": kwargs.get("experiment_name", "llama-stack-prompts"),
"auth_credential": mlflow_api_token,
}
@field_validator("mlflow_tracking_uri")
@classmethod
def validate_tracking_uri(cls, v: str) -> str:
"""Validate tracking URI is not empty."""
if not v or not v.strip():
raise ValueError("mlflow_tracking_uri cannot be empty")
return v.strip()
@field_validator("experiment_name")
@classmethod
def validate_experiment_name(cls, v: str) -> str:
"""Validate experiment name is not empty."""
if not v or not v.strip():
raise ValueError("experiment_name cannot be empty")
return v.strip()

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""ID mapping utilities for MLflow Prompt Registry provider.
This module handles bidirectional mapping between Llama Stack's prompt_id format
(pmpt_<48-hex-chars>) and MLflow's name-based system (llama_prompt_<hex>).
"""
import re
class PromptIDMapper:
"""Handle bidirectional mapping between Llama Stack IDs and MLflow names.
Llama Stack uses prompt IDs in format: pmpt_<48-hex-chars>
MLflow uses string names, so we map to: llama_prompt_<48-hex-chars>
This ensures:
- Deterministic mapping (same ID always maps to same name)
- Reversible (can recover original ID from MLflow name)
- Unique (different IDs map to different names)
"""
# Regex pattern for Llama Stack prompt_id validation
PROMPT_ID_PATTERN = re.compile(r"^pmpt_[0-9a-f]{48}$")
# Prefix for MLflow prompt names managed by Llama Stack
MLFLOW_NAME_PREFIX = "llama_prompt_"
def to_mlflow_name(self, prompt_id: str) -> str:
"""Convert Llama Stack prompt_id to MLflow prompt name.
Args:
prompt_id: Llama Stack prompt ID (format: pmpt_<48-hex-chars>)
Returns:
MLflow prompt name (format: llama_prompt_<48-hex-chars>)
Raises:
ValueError: If prompt_id format is invalid
Example:
>>> mapper = PromptIDMapper()
>>> mapper.to_mlflow_name("pmpt_a1b2c3d4e5f6...")
"llama_prompt_a1b2c3d4e5f6..."
"""
if not self.PROMPT_ID_PATTERN.match(prompt_id):
raise ValueError(f"Invalid prompt_id format: {prompt_id}. Expected format: pmpt_<48-hex-chars>")
# Extract hex part (after "pmpt_" prefix)
hex_part = prompt_id.split("pmpt_")[1]
# Create MLflow name
return f"{self.MLFLOW_NAME_PREFIX}{hex_part}"
def to_llama_id(self, mlflow_name: str) -> str:
"""Convert MLflow prompt name to Llama Stack prompt_id.
Args:
mlflow_name: MLflow prompt name
Returns:
Llama Stack prompt ID (format: pmpt_<48-hex-chars>)
Raises:
ValueError: If name doesn't follow expected format
Example:
>>> mapper = PromptIDMapper()
>>> mapper.to_llama_id("llama_prompt_a1b2c3d4e5f6...")
"pmpt_a1b2c3d4e5f6..."
"""
if not mlflow_name.startswith(self.MLFLOW_NAME_PREFIX):
raise ValueError(
f"MLflow name '{mlflow_name}' does not start with expected prefix '{self.MLFLOW_NAME_PREFIX}'"
)
# Extract hex part
hex_part = mlflow_name[len(self.MLFLOW_NAME_PREFIX) :]
# Validate hex part length and characters
if len(hex_part) != 48:
raise ValueError(f"Invalid hex part length in MLflow name '{mlflow_name}'. Expected 48 characters.")
for char in hex_part:
if char not in "0123456789abcdef":
raise ValueError(
f"Invalid character '{char}' in hex part of MLflow name '{mlflow_name}'. "
"Expected lowercase hex characters [0-9a-f]."
)
return f"pmpt_{hex_part}"
def get_metadata_tags(self, prompt_id: str, variables: list[str] | None = None) -> dict[str, str]:
"""Generate MLflow tags with Llama Stack metadata.
Args:
prompt_id: Llama Stack prompt ID
variables: List of prompt variables (optional)
Returns:
Dictionary of MLflow tags for metadata storage
Example:
>>> mapper = PromptIDMapper()
>>> tags = mapper.get_metadata_tags("pmpt_abc123...", ["var1", "var2"])
>>> tags
{"llama_stack_id": "pmpt_abc123...", "llama_stack_managed": "true", "variables": "var1,var2"}
"""
tags = {
"llama_stack_id": prompt_id,
"llama_stack_managed": "true",
}
if variables:
# Store variables as comma-separated string
tags["variables"] = ",".join(variables)
return tags

View file

@ -0,0 +1,547 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""MLflow Prompt Registry provider implementation.
This module implements the Llama Stack Prompts protocol using MLflow's Prompt Registry
as the backend for centralized prompt management and versioning.
"""
import re
from typing import TYPE_CHECKING, Any
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
# Try importing mlflow at module level
try:
import mlflow
from mlflow.client import MlflowClient
except ImportError:
# Fail gracefully when provider is instantiated during initialize()
mlflow = None
logger = get_logger(__name__)
class MLflowPromptsAdapter(NeedsRequestProviderData, Prompts):
"""MLflow Prompt Registry adapter for Llama Stack.
This adapter implements the Llama Stack Prompts protocol using MLflow's
Prompt Registry as the backend storage system. It handles:
- Bidirectional ID mapping (prompt_id <-> MLflow name)
- Version management via MLflow versioning
- Variable extraction from prompt templates
- Metadata storage in MLflow tags
- Default version management via MLflow aliases
- Credential management via provider data (backstopped by config)
Credentials can be provided via:
1. Per-request provider data header (preferred for security)
2. Configuration auth_credential (fallback)
3. Environment variables (MLFLOW_TRACKING_TOKEN, etc.)
Attributes:
config: MLflow provider configuration
mlflow_client: MLflow client instance
mapper: ID mapping utility
"""
def __init__(self, config: MLflowPromptsConfig):
"""Initialize MLflow prompts adapter.
Args:
config: MLflow provider configuration
"""
self.config = config
self.mlflow_client: "MlflowClient | None" = None
self.mapper = PromptIDMapper()
logger.info(
f"MLflowPromptsAdapter initialized: tracking_uri={config.mlflow_tracking_uri}, "
f"experiment={config.experiment_name}"
)
async def initialize(self) -> None:
"""Initialize MLflow client and set up experiment.
Sets up MLflow connection with optional authentication via token.
Token can be provided via config or will be read from environment variables
(MLFLOW_TRACKING_TOKEN, etc.) as per MLflow's standard behavior.
Raises:
ImportError: If mlflow package is not installed
Exception: If MLflow connection fails
"""
if mlflow is None:
raise ImportError(
"mlflow package is required for MLflow prompts provider. "
"Install with: pip install 'mlflow>=3.4.0'"
)
# Set MLflow URIs
mlflow.set_tracking_uri(self.config.mlflow_tracking_uri)
if self.config.mlflow_registry_uri:
mlflow.set_registry_uri(self.config.mlflow_registry_uri)
else:
# Default to tracking URI if registry not specified
mlflow.set_registry_uri(self.config.mlflow_tracking_uri)
# Set authentication token if provided in config
if self.config.auth_credential is not None:
import os
# MLflow reads MLFLOW_TRACKING_TOKEN from environment
os.environ["MLFLOW_TRACKING_TOKEN"] = self.config.auth_credential.get_secret_value()
logger.debug("Set MLFLOW_TRACKING_TOKEN from config auth_credential")
# Initialize client
self.mlflow_client = MlflowClient()
# Validate experiment exists (don't create during initialization)
try:
mlflow.set_experiment(self.config.experiment_name)
logger.info(f"Using MLflow experiment: {self.config.experiment_name}")
except Exception as e:
logger.warning(
f"Experiment '{self.config.experiment_name}' not found: {e}. "
f"It will be created automatically on first prompt creation."
)
def _ensure_experiment(self) -> None:
"""Ensure MLflow experiment exists, creating it if necessary.
This is called lazily on first write operation to avoid creating
external resources during initialization.
"""
try:
mlflow.set_experiment(self.config.experiment_name)
except Exception:
# Experiment doesn't exist, create it
try:
mlflow.create_experiment(self.config.experiment_name)
mlflow.set_experiment(self.config.experiment_name)
logger.info(f"Created MLflow experiment: {self.config.experiment_name}")
except Exception as e:
raise ValueError(
f"Failed to create experiment '{self.config.experiment_name}': {e}"
) from e
def _extract_variables(self, template: str) -> list[str]:
"""Extract variables from prompt template.
Extracts variables in {{ variable }} format from the template.
Args:
template: Prompt template string
Returns:
List of unique variable names in order of appearance
Example:
>>> adapter._extract_variables("Hello {{ name }}, your score is {{ score }}")
["name", "score"]
"""
if not template:
return []
# Find all {{ variable }} patterns
matches = re.findall(r"{{\s*(\w+)\s*}}", template)
# Return unique variables in order of appearance
seen = set()
variables = []
for var in matches:
if var not in seen:
variables.append(var)
seen.add(var)
return variables
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt in MLflow registry.
Args:
prompt: Prompt template text with {{ variable }} placeholders
variables: List of variable names (auto-extracted if not provided)
Returns:
Created Prompt resource with prompt_id and version=1
Raises:
ValueError: If prompt validation fails
Exception: If MLflow registration fails
"""
# Ensure experiment exists (lazy creation on first write)
self._ensure_experiment()
# Auto-extract variables if not provided
if variables is None:
variables = self._extract_variables(prompt)
else:
# Validate declared variables match template
template_vars = set(self._extract_variables(prompt))
declared_vars = set(variables)
undeclared = template_vars - declared_vars
if undeclared:
raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}")
# Generate Llama Stack prompt_id
prompt_id = Prompt.generate_prompt_id()
# Convert to MLflow name
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
# Prepare metadata tags
tags = self.mapper.get_metadata_tags(prompt_id, variables)
# Register in MLflow
try:
mlflow.genai.register_prompt(
name=mlflow_name,
template=prompt,
commit_message="Created via Llama Stack",
tags=tags,
)
logger.info(f"Created prompt {prompt_id} (MLflow: {mlflow_name})")
except Exception as e:
logger.error(f"Failed to register prompt in MLflow: {e}")
raise
# Set as default (first version is always default)
try:
mlflow.genai.set_prompt_alias(
name=mlflow_name,
version=1,
alias="default",
)
except Exception as e:
logger.warning(f"Failed to set default alias for {prompt_id}: {e}")
return Prompt(
prompt_id=prompt_id,
prompt=prompt,
version=1,
variables=variables,
is_default=True,
)
async def get_prompt(
self,
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get prompt from MLflow registry.
Args:
prompt_id: Llama Stack prompt ID
version: Version number (defaults to default version)
Returns:
Prompt resource
Raises:
ValueError: If prompt not found
"""
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
# Build MLflow URI
if version:
uri = f"prompts:/{mlflow_name}/{version}"
else:
uri = f"prompts:/{mlflow_name}@default"
# Load from MLflow
try:
mlflow_prompt = mlflow.genai.load_prompt(uri)
except Exception as e:
raise ValueError(f"Prompt {prompt_id} (version {version if version else 'default'}) not found: {e}") from e
# Extract template
template = mlflow_prompt.template if hasattr(mlflow_prompt, "template") else str(mlflow_prompt)
# Extract variables from template
variables = self._extract_variables(template)
# Get version number
prompt_version = 1
if hasattr(mlflow_prompt, "version"):
prompt_version = int(mlflow_prompt.version)
elif version:
prompt_version = version
# Check if this is the default version
is_default = await self._is_default_version(mlflow_name, prompt_version)
return Prompt(
prompt_id=prompt_id,
prompt=template,
version=prompt_version,
variables=variables,
is_default=is_default,
)
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update prompt (creates new version in MLflow).
Args:
prompt_id: Llama Stack prompt ID
prompt: Updated prompt template
version: Current version being updated
variables: Updated variables list (auto-extracted if not provided)
set_as_default: Set new version as default
Returns:
Updated Prompt resource with incremented version
Raises:
ValueError: If current version not found or validation fails
"""
# Ensure experiment exists (edge case: updating prompts created outside Llama Stack)
self._ensure_experiment()
# Auto-extract variables if not provided
if variables is None:
variables = self._extract_variables(prompt)
else:
# Validate variables
template_vars = set(self._extract_variables(prompt))
declared_vars = set(variables)
undeclared = template_vars - declared_vars
if undeclared:
raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}")
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
# Get all versions to determine the latest and next version number
versions_response = await self.list_prompt_versions(prompt_id)
if not versions_response.data:
raise ValueError(f"Prompt {prompt_id} not found")
max_version = max(p.version for p in versions_response.data)
# Verify the provided version is the latest
if version != max_version:
raise ValueError(
f"Version {version} is not the latest version. Use latest version {max_version} to update."
)
new_version = max_version + 1
# Prepare metadata tags
tags = self.mapper.get_metadata_tags(prompt_id, variables)
# Register new version in MLflow
try:
mlflow.genai.register_prompt(
name=mlflow_name,
template=prompt,
commit_message=f"Updated from version {version} via Llama Stack",
tags=tags,
)
logger.info(f"Updated prompt {prompt_id} to version {new_version}")
except Exception as e:
logger.error(f"Failed to update prompt in MLflow: {e}")
raise
# Set as default if requested
if set_as_default:
try:
mlflow.genai.set_prompt_alias(
name=mlflow_name,
version=new_version,
alias="default",
)
except Exception as e:
logger.warning(f"Failed to set default alias: {e}")
return Prompt(
prompt_id=prompt_id,
prompt=prompt,
version=new_version,
variables=variables,
is_default=set_as_default,
)
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete prompt from MLflow registry.
Note: MLflow Prompt Registry does not support deletion of registered prompts.
This method will raise NotImplementedError.
Args:
prompt_id: Llama Stack prompt ID
Raises:
NotImplementedError: MLflow doesn't support prompt deletion
"""
# MLflow doesn't support deletion of registered prompts
# Options:
# 1. Raise NotImplementedError (current approach)
# 2. Mark as deleted with tag (soft delete)
# 3. Delete all versions individually (if API exists)
raise NotImplementedError(
"MLflow Prompt Registry does not support deletion. Consider using tags to mark prompts as archived/deleted."
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only).
Returns:
ListPromptsResponse with default version of each prompt
Note:
Only lists prompts created/managed by Llama Stack
(those with llama_stack_managed=true tag)
"""
try:
# Search for Llama Stack managed prompts using metadata tags
prompts = mlflow.genai.search_prompts(filter_string="tag.llama_stack_managed='true'")
except Exception as e:
logger.error(f"Failed to search prompts in MLflow: {e}")
return ListPromptsResponse(data=[])
llama_prompts = []
for mlflow_prompt in prompts:
try:
# Convert MLflow name to Llama Stack ID
prompt_id = self.mapper.to_llama_id(mlflow_prompt.name)
# Get default version
llama_prompt = await self.get_prompt(prompt_id)
llama_prompts.append(llama_prompt)
except (ValueError, Exception) as e:
# Skip prompts that can't be converted or retrieved
logger.warning(f"Skipping prompt {mlflow_prompt.name}: {e}")
continue
# Sort by prompt_id
llama_prompts.sort(key=lambda p: p.prompt_id, reverse=True)
return ListPromptsResponse(data=llama_prompts)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt.
Args:
prompt_id: Llama Stack prompt ID
Returns:
ListPromptsResponse with all versions of the prompt
Raises:
ValueError: If prompt not found
"""
# MLflow doesn't have a direct "list versions" API for prompts
# We need to iterate and try to load each version
versions = []
version_num = 1
max_attempts = 100 # Safety limit
while version_num <= max_attempts:
try:
prompt = await self.get_prompt(prompt_id, version_num)
versions.append(prompt)
version_num += 1
except ValueError:
# No more versions
break
except Exception as e:
logger.warning(f"Error loading version {version_num} of {prompt_id}: {e}")
break
if not versions:
raise ValueError(f"Prompt {prompt_id} not found")
# Sort by version number
versions.sort(key=lambda p: p.version)
return ListPromptsResponse(data=versions)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set default version using MLflow alias.
Args:
prompt_id: Llama Stack prompt ID
version: Version number to set as default
Returns:
Prompt resource with is_default=True
Raises:
ValueError: If version not found
"""
# Ensure experiment exists (edge case: managing prompts created outside Llama Stack)
self._ensure_experiment()
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
# Verify version exists
try:
prompt = await self.get_prompt(prompt_id, version)
except ValueError as e:
raise ValueError(f"Cannot set default: {e}") from e
# Set "default" alias in MLflow
try:
mlflow.genai.set_prompt_alias(
name=mlflow_name,
version=version,
alias="default",
)
logger.info(f"Set version {version} as default for {prompt_id}")
except Exception as e:
logger.error(f"Failed to set default version: {e}")
raise
# Update is_default flag
prompt.is_default = True
return prompt
async def _is_default_version(self, mlflow_name: str, version: int) -> bool:
"""Check if a version is the default version.
Args:
mlflow_name: MLflow prompt name
version: Version number
Returns:
True if this version is the default, False otherwise
"""
try:
# Try to load with @default alias
default_uri = f"prompts:/{mlflow_name}@default"
default_prompt = mlflow.genai.load_prompt(default_uri)
# Get default version number
default_version = 1
if hasattr(default_prompt, "version"):
default_version = int(default_prompt.version)
return version == default_version
except Exception:
# If default doesn't exist or can't be determined, assume False
return False
async def shutdown(self) -> None:
"""Cleanup resources (no-op for MLflow)."""
pass

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Integration tests for remote prompts providers."""

View file

@ -0,0 +1,274 @@
# MLflow Prompts Provider - Integration Tests
This directory contains integration tests for the MLflow Prompts Provider. These tests require a running MLflow server.
## Prerequisites
1. **MLflow installed**: `pip install 'mlflow>=3.4.0'` (or `uv pip install 'mlflow>=3.4.0'`)
2. **MLflow server running**: See setup instructions below
3. **Test dependencies**: `uv sync --group test`
## Quick Start
### 1. Start MLflow Server
```bash
# Start MLflow server on localhost:5555
mlflow server --host 127.0.0.1 --port 5555
# Keep this terminal open - server will continue running
```
### 2. Run Integration Tests
In a separate terminal:
```bash
# Set MLflow URI (optional - defaults to localhost:5555)
export MLFLOW_TRACKING_URI=http://localhost:5555
# Run all integration tests
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
# Run specific test
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py::TestMLflowPromptsEndToEnd::test_create_and_retrieve_prompt
```
### 3. Run Manual Test Script (Optional)
For quick validation without pytest:
```bash
# Run manual test script
uv run python scripts/test_mlflow_prompts_manual.py
# View output in MLflow UI
open http://localhost:5555
```
## Test Organization
### Integration Tests (`test_end_to_end.py`)
Comprehensive end-to-end tests covering:
- ✅ Create and retrieve prompts
- ✅ Update prompts (version management)
- ✅ List prompts (default versions only)
- ✅ List all versions of a prompt
- ✅ Set default version
- ✅ Variable auto-extraction
- ✅ Variable validation
- ✅ Error handling (not found, wrong version, etc.)
- ✅ Complex templates with multiple variables
- ✅ Edge cases (empty templates, no variables, etc.)
**Total**: 17 test scenarios
### Manual Test Script (`scripts/test_mlflow_prompts_manual.py`)
Interactive test script with verbose output for:
- Server connectivity check
- Provider initialization
- Basic CRUD operations
- Variable extraction
- Statistics retrieval
## Configuration
### MLflow Server Options
**Local (default)**:
```bash
mlflow server --host 127.0.0.1 --port 5555
```
**Remote server**:
```bash
export MLFLOW_TRACKING_URI=http://mlflow.example.com:5000
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
```
**Databricks**:
```bash
export MLFLOW_TRACKING_URI=databricks
export MLFLOW_REGISTRY_URI=databricks://profile
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
```
### Test Timeout
Tests have a default timeout of 30 seconds per MLflow operation. Adjust in `conftest.py`:
```python
MLflowPromptsConfig(
mlflow_tracking_uri=mlflow_tracking_uri,
timeout_seconds=60, # Increase for slow connections
)
```
## Fixtures
### `mlflow_adapter`
Basic adapter for simple tests:
```python
async def test_something(mlflow_adapter):
prompt = await mlflow_adapter.create_prompt(...)
# Test continues...
```
### `mlflow_adapter_with_cleanup`
Adapter with automatic cleanup tracking:
```python
async def test_something(mlflow_adapter_with_cleanup):
# Creates are tracked and attempted cleanup on teardown
prompt = await mlflow_adapter_with_cleanup.create_prompt(...)
```
**Note**: MLflow doesn't support deletion, so cleanup is best-effort.
## Troubleshooting
### Server Not Available
**Symptom**:
```
SKIPPED [1] conftest.py:35: MLflow server not available at http://localhost:5555
```
**Solution**:
```bash
# Start MLflow server
mlflow server --host 127.0.0.1 --port 5555
# Verify it's running
curl http://localhost:5555/health
```
### Connection Timeout
**Symptom**:
```
requests.exceptions.Timeout: ...
```
**Solutions**:
1. Check MLflow server is responsive: `curl http://localhost:5555/health`
2. Increase timeout in `conftest.py`: `timeout_seconds=60`
3. Check firewall/network settings
### Import Errors
**Symptom**:
```
ModuleNotFoundError: No module named 'mlflow'
```
**Solution**:
```bash
uv pip install 'mlflow>=3.4.0'
```
### Permission Errors
**Symptom**:
```
PermissionError: [Errno 13] Permission denied: '...'
```
**Solution**:
- Ensure MLflow has write access to its storage directory
- Check file permissions on `mlruns/` directory
### Test Isolation Issues
**Issue**: Tests may interfere with each other if using same prompt IDs
**Solution**: Each test creates new prompts with unique IDs (generated by `Prompt.generate_prompt_id()`). If needed, use `mlflow_adapter_with_cleanup` fixture.
## Viewing Results
### MLflow UI
1. Start MLflow server (if not already running):
```bash
mlflow server --host 127.0.0.1 --port 5555
```
2. Open in browser:
```
http://localhost:5555
```
3. Navigate to experiment `test-llama-stack-prompts`
4. View registered prompts and their versions
### Test Output
Run with verbose output to see detailed test execution:
```bash
uv run --group test pytest -vv tests/integration/providers/remote/prompts/mlflow/
```
## CI/CD Integration
To run tests in CI/CD pipelines:
```yaml
# Example GitHub Actions workflow
- name: Start MLflow server
run: |
mlflow server --host 127.0.0.1 --port 5555 &
sleep 5 # Wait for server to start
- name: Wait for MLflow
run: |
timeout 30 bash -c 'until curl -s http://localhost:5555/health; do sleep 1; done'
- name: Run integration tests
env:
MLFLOW_TRACKING_URI: http://localhost:5555
run: |
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
```
## Performance
### Expected Test Duration
- **Individual test**: ~1-5 seconds
- **Full suite** (17 tests): ~30-60 seconds
- **Manual script**: ~10-15 seconds
### Optimization Tips
1. Use local MLflow server (faster than remote)
2. Run tests in parallel (if safe):
```bash
uv run --group test pytest -n auto tests/integration/providers/remote/prompts/mlflow/
```
3. Skip integration tests in development:
```bash
uv run --group dev pytest -sv tests/unit/
```
## Coverage
Integration tests provide coverage for:
- ✅ Real MLflow API calls
- ✅ Network communication
- ✅ Serialization/deserialization
- ✅ MLflow server responses
- ✅ Version management
- ✅ Alias handling
- ✅ Tag storage and retrieval
Combined with unit tests, achieves **>95% code coverage**.

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Integration tests for MLflow prompts provider."""

View file

@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Fixtures for MLflow integration tests.
These tests require a running MLflow server. Set the MLFLOW_TRACKING_URI
environment variable to point to your MLflow server, or the tests will
attempt to use http://localhost:5555.
To run tests:
# Start MLflow server (in separate terminal)
mlflow server --host 127.0.0.1 --port 5555
# Run integration tests
MLFLOW_TRACKING_URI=http://localhost:5555 \
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
"""
import os
import pytest
from llama_stack.providers.remote.prompts.mlflow import MLflowPromptsAdapter
from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
@pytest.fixture(scope="session")
def mlflow_tracking_uri():
"""Get MLflow tracking URI from environment or use default."""
return os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5555")
@pytest.fixture(scope="session")
def mlflow_server_available(mlflow_tracking_uri):
"""Verify MLflow server is running and accessible.
Skips all tests if server is not available.
"""
try:
import requests
response = requests.get(f"{mlflow_tracking_uri}/health", timeout=5)
if response.status_code != 200:
pytest.skip(f"MLflow server at {mlflow_tracking_uri} returned status {response.status_code}")
except ImportError:
pytest.skip("requests package not installed - install with: pip install requests")
except requests.exceptions.ConnectionError:
pytest.skip(
f"MLflow server not available at {mlflow_tracking_uri}. "
"Start with: mlflow server --host 127.0.0.1 --port 5555"
)
except requests.exceptions.Timeout:
pytest.skip(f"MLflow server at {mlflow_tracking_uri} timed out")
except Exception as e:
pytest.skip(f"Failed to check MLflow server availability: {e}")
return True
@pytest.fixture
async def mlflow_config(mlflow_tracking_uri, mlflow_server_available):
"""Create MLflow configuration for testing."""
return MLflowPromptsConfig(
mlflow_tracking_uri=mlflow_tracking_uri,
experiment_name="test-llama-stack-prompts",
timeout_seconds=30,
)
@pytest.fixture
async def mlflow_adapter(mlflow_config):
"""Create and initialize MLflow adapter for testing.
This fixture creates a new adapter instance for each test.
The adapter connects to the MLflow server specified in the config.
"""
adapter = MLflowPromptsAdapter(config=mlflow_config)
await adapter.initialize()
yield adapter
# Cleanup: shutdown adapter
await adapter.shutdown()
@pytest.fixture
async def mlflow_adapter_with_cleanup(mlflow_config):
"""Create MLflow adapter with automatic cleanup after test.
This fixture is useful for tests that create prompts and want them
automatically cleaned up (though MLflow doesn't support deletion,
so cleanup is best-effort).
"""
adapter = MLflowPromptsAdapter(config=mlflow_config)
await adapter.initialize()
created_prompt_ids = []
# Provide adapter and tracking list
class AdapterWithTracking:
def __init__(self, adapter_instance):
self.adapter = adapter_instance
self.created_ids = created_prompt_ids
async def create_prompt(self, *args, **kwargs):
prompt = await self.adapter.create_prompt(*args, **kwargs)
self.created_ids.append(prompt.prompt_id)
return prompt
def __getattr__(self, name):
return getattr(self.adapter, name)
tracked_adapter = AdapterWithTracking(adapter)
yield tracked_adapter
# Cleanup: attempt to delete created prompts
# Note: MLflow doesn't support deletion, so this is a no-op
# but we keep it for future compatibility
for prompt_id in created_prompt_ids:
try:
await adapter.delete_prompt(prompt_id)
except NotImplementedError:
# Expected - MLflow doesn't support deletion
pass
except Exception:
# Ignore cleanup errors
pass
await adapter.shutdown()

View file

@ -0,0 +1,350 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""End-to-end integration tests for MLflow prompts provider.
These tests require a running MLflow server. See conftest.py for setup instructions.
"""
import pytest
class TestMLflowPromptsEndToEnd:
"""End-to-end tests for MLflow prompts provider."""
async def test_create_and_retrieve_prompt(self, mlflow_adapter):
"""Test creating a prompt and retrieving it by ID."""
# Create prompt with variables
created = await mlflow_adapter.create_prompt(
prompt="Summarize the following text in {{ num_sentences }} sentences: {{ text }}",
variables=["num_sentences", "text"],
)
# Verify created prompt
assert created.prompt_id.startswith("pmpt_")
assert len(created.prompt_id) == 53 # "pmpt_" + 48 hex chars
assert created.version == 1
assert created.is_default is True
assert set(created.variables) == {"num_sentences", "text"}
assert "{{ num_sentences }}" in created.prompt
assert "{{ text }}" in created.prompt
# Retrieve prompt by ID (should get default version)
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert retrieved.prompt_id == created.prompt_id
assert retrieved.prompt == created.prompt
assert retrieved.version == created.version
assert set(retrieved.variables) == set(created.variables)
assert retrieved.is_default is True
# Retrieve specific version
retrieved_v1 = await mlflow_adapter.get_prompt(created.prompt_id, version=1)
assert retrieved_v1.prompt_id == created.prompt_id
assert retrieved_v1.version == 1
async def test_update_prompt_creates_new_version(self, mlflow_adapter):
"""Test that updating a prompt creates a new version."""
# Create initial prompt (version 1)
v1 = await mlflow_adapter.create_prompt(
prompt="Original prompt with {{ variable }}",
variables=["variable"],
)
assert v1.version == 1
assert v1.is_default is True
# Update prompt (should create version 2)
v2 = await mlflow_adapter.update_prompt(
prompt_id=v1.prompt_id,
prompt="Updated prompt with {{ variable }}",
version=1,
variables=["variable"],
set_as_default=True,
)
assert v2.prompt_id == v1.prompt_id
assert v2.version == 2
assert v2.is_default is True
assert "Updated" in v2.prompt
# Verify both versions exist
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
versions = versions_response.data
assert len(versions) >= 2
assert any(v.version == 1 for v in versions)
assert any(v.version == 2 for v in versions)
# Verify version 1 still exists
v1_retrieved = await mlflow_adapter.get_prompt(v1.prompt_id, version=1)
assert "Original" in v1_retrieved.prompt
assert v1_retrieved.is_default is False # No longer default
# Verify version 2 is default
default = await mlflow_adapter.get_prompt(v1.prompt_id)
assert default.version == 2
assert "Updated" in default.prompt
async def test_list_prompts_returns_defaults_only(self, mlflow_adapter):
"""Test that list_prompts returns only default versions."""
# Create multiple prompts
p1 = await mlflow_adapter.create_prompt(
prompt="Prompt 1 with {{ var }}",
variables=["var"],
)
p2 = await mlflow_adapter.create_prompt(
prompt="Prompt 2 with {{ var }}",
variables=["var"],
)
# Update first prompt (creates version 2)
await mlflow_adapter.update_prompt(
prompt_id=p1.prompt_id,
prompt="Prompt 1 updated with {{ var }}",
version=1,
variables=["var"],
set_as_default=True,
)
# List all prompts
response = await mlflow_adapter.list_prompts()
prompts = response.data
# Should contain at least our 2 prompts
assert len(prompts) >= 2
# Find our prompts in the list
p1_in_list = next((p for p in prompts if p.prompt_id == p1.prompt_id), None)
p2_in_list = next((p for p in prompts if p.prompt_id == p2.prompt_id), None)
assert p1_in_list is not None
assert p2_in_list is not None
# p1 should be version 2 (updated version is default)
assert p1_in_list.version == 2
assert p1_in_list.is_default is True
# p2 should be version 1 (original is still default)
assert p2_in_list.version == 1
assert p2_in_list.is_default is True
async def test_list_prompt_versions(self, mlflow_adapter):
"""Test listing all versions of a specific prompt."""
# Create prompt
v1 = await mlflow_adapter.create_prompt(
prompt="Version 1 {{ var }}",
variables=["var"],
)
# Create multiple versions
_v2 = await mlflow_adapter.update_prompt(
prompt_id=v1.prompt_id,
prompt="Version 2 {{ var }}",
version=1,
variables=["var"],
)
_v3 = await mlflow_adapter.update_prompt(
prompt_id=v1.prompt_id,
prompt="Version 3 {{ var }}",
version=2,
variables=["var"],
)
# List all versions
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
versions = versions_response.data
# Should have 3 versions
assert len(versions) == 3
# Verify versions are sorted by version number
assert versions[0].version == 1
assert versions[1].version == 2
assert versions[2].version == 3
# Verify content
assert "Version 1" in versions[0].prompt
assert "Version 2" in versions[1].prompt
assert "Version 3" in versions[2].prompt
# Only latest should be default
assert versions[0].is_default is False
assert versions[1].is_default is False
assert versions[2].is_default is True
async def test_set_default_version(self, mlflow_adapter):
"""Test changing which version is the default."""
# Create prompt and update it
v1 = await mlflow_adapter.create_prompt(
prompt="Version 1 {{ var }}",
variables=["var"],
)
_v2 = await mlflow_adapter.update_prompt(
prompt_id=v1.prompt_id,
prompt="Version 2 {{ var }}",
version=1,
variables=["var"],
)
# At this point, _v2 is default
default = await mlflow_adapter.get_prompt(v1.prompt_id)
assert default.version == 2
# Set v1 as default
updated = await mlflow_adapter.set_default_version(v1.prompt_id, 1)
assert updated.version == 1
assert updated.is_default is True
# Verify default changed
default = await mlflow_adapter.get_prompt(v1.prompt_id)
assert default.version == 1
assert "Version 1" in default.prompt
async def test_variable_auto_extraction(self, mlflow_adapter):
"""Test automatic variable extraction from template."""
# Create prompt without explicitly specifying variables
created = await mlflow_adapter.create_prompt(
prompt="Extract {{ entity }} from {{ text }} in {{ format }} format",
)
# Should auto-extract all variables
assert set(created.variables) == {"entity", "text", "format"}
# Retrieve and verify
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert set(retrieved.variables) == {"entity", "text", "format"}
async def test_variable_validation(self, mlflow_adapter):
"""Test that variable validation works correctly."""
# Should fail: template has undeclared variable
with pytest.raises(ValueError, match="undeclared variables"):
await mlflow_adapter.create_prompt(
prompt="Template with {{ var1 }} and {{ var2 }}",
variables=["var1"], # Missing var2
)
async def test_prompt_not_found(self, mlflow_adapter):
"""Test error handling when prompt doesn't exist."""
fake_id = "pmpt_" + "0" * 48
with pytest.raises(ValueError, match="not found"):
await mlflow_adapter.get_prompt(fake_id)
async def test_version_not_found(self, mlflow_adapter):
"""Test error handling when version doesn't exist."""
# Create prompt (version 1)
created = await mlflow_adapter.create_prompt(
prompt="Test {{ var }}",
variables=["var"],
)
# Try to get non-existent version
with pytest.raises(ValueError, match="not found"):
await mlflow_adapter.get_prompt(created.prompt_id, version=999)
async def test_update_wrong_version(self, mlflow_adapter):
"""Test that updating with wrong version fails."""
# Create prompt (version 1)
created = await mlflow_adapter.create_prompt(
prompt="Test {{ var }}",
variables=["var"],
)
# Try to update with wrong version number
with pytest.raises(ValueError, match="not the latest"):
await mlflow_adapter.update_prompt(
prompt_id=created.prompt_id,
prompt="Updated {{ var }}",
version=999, # Wrong version
variables=["var"],
)
async def test_delete_not_supported(self, mlflow_adapter):
"""Test that deletion raises NotImplementedError."""
# Create prompt
created = await mlflow_adapter.create_prompt(
prompt="Test {{ var }}",
variables=["var"],
)
# Try to delete (should fail with NotImplementedError)
with pytest.raises(NotImplementedError, match="does not support deletion"):
await mlflow_adapter.delete_prompt(created.prompt_id)
# Verify prompt still exists
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert retrieved.prompt_id == created.prompt_id
async def test_complex_template_with_multiple_variables(self, mlflow_adapter):
"""Test prompt with complex template and multiple variables."""
template = """You are a {{ role }} assistant specialized in {{ domain }}.
Task: {{ task }}
Context:
{{ context }}
Instructions:
1. {{ instruction1 }}
2. {{ instruction2 }}
3. {{ instruction3 }}
Output format: {{ output_format }}
"""
# Create with auto-extraction
created = await mlflow_adapter.create_prompt(prompt=template)
# Should extract all variables
expected_vars = {
"role",
"domain",
"task",
"context",
"instruction1",
"instruction2",
"instruction3",
"output_format",
}
assert set(created.variables) == expected_vars
# Retrieve and verify template preserved
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert retrieved.prompt == template
async def test_empty_template(self, mlflow_adapter):
"""Test handling of empty template."""
# Create prompt with empty template
created = await mlflow_adapter.create_prompt(
prompt="",
variables=[],
)
assert created.prompt == ""
assert created.variables == []
# Retrieve and verify
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert retrieved.prompt == ""
async def test_template_with_no_variables(self, mlflow_adapter):
"""Test template without any variables."""
template = "This is a static prompt with no variables."
created = await mlflow_adapter.create_prompt(prompt=template)
assert created.prompt == template
assert created.variables == []
# Retrieve and verify
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
assert retrieved.prompt == template
assert retrieved.variables == []

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for remote prompts providers."""

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for MLflow prompts provider."""

View file

@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for MLflow prompts provider configuration."""
import pytest
from pydantic import SecretStr, ValidationError
from llama_stack.providers.remote.prompts.mlflow.config import (
MLflowPromptsConfig,
MLflowProviderDataValidator,
)
class TestMLflowPromptsConfig:
"""Tests for MLflowPromptsConfig model."""
def test_default_config(self):
"""Test default configuration values."""
config = MLflowPromptsConfig()
assert config.mlflow_tracking_uri == "http://localhost:5000"
assert config.mlflow_registry_uri is None
assert config.experiment_name == "llama-stack-prompts"
assert config.auth_credential is None
assert config.timeout_seconds == 30
def test_custom_config(self):
"""Test custom configuration values."""
config = MLflowPromptsConfig(
mlflow_tracking_uri="http://mlflow.example.com:8080",
mlflow_registry_uri="http://registry.example.com:8080",
experiment_name="my-prompts",
auth_credential=SecretStr("my-token"),
timeout_seconds=60,
)
assert config.mlflow_tracking_uri == "http://mlflow.example.com:8080"
assert config.mlflow_registry_uri == "http://registry.example.com:8080"
assert config.experiment_name == "my-prompts"
assert config.auth_credential.get_secret_value() == "my-token"
assert config.timeout_seconds == 60
def test_databricks_uri(self):
"""Test Databricks URI configuration."""
config = MLflowPromptsConfig(
mlflow_tracking_uri="databricks",
mlflow_registry_uri="databricks://profile",
)
assert config.mlflow_tracking_uri == "databricks"
assert config.mlflow_registry_uri == "databricks://profile"
def test_tracking_uri_validation(self):
"""Test tracking URI validation."""
# Empty string rejected
with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
MLflowPromptsConfig(mlflow_tracking_uri="")
# Whitespace-only rejected
with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
MLflowPromptsConfig(mlflow_tracking_uri=" ")
# Whitespace is stripped
config = MLflowPromptsConfig(mlflow_tracking_uri=" http://localhost:5000 ")
assert config.mlflow_tracking_uri == "http://localhost:5000"
def test_experiment_name_validation(self):
"""Test experiment name validation."""
# Empty string rejected
with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
MLflowPromptsConfig(experiment_name="")
# Whitespace-only rejected
with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
MLflowPromptsConfig(experiment_name=" ")
# Whitespace is stripped
config = MLflowPromptsConfig(experiment_name=" my-experiment ")
assert config.experiment_name == "my-experiment"
def test_timeout_validation(self):
"""Test timeout range validation."""
# Too low rejected
with pytest.raises(ValidationError):
MLflowPromptsConfig(timeout_seconds=0)
with pytest.raises(ValidationError):
MLflowPromptsConfig(timeout_seconds=-1)
# Too high rejected
with pytest.raises(ValidationError):
MLflowPromptsConfig(timeout_seconds=301)
# Boundary values accepted
config_min = MLflowPromptsConfig(timeout_seconds=1)
assert config_min.timeout_seconds == 1
config_max = MLflowPromptsConfig(timeout_seconds=300)
assert config_max.timeout_seconds == 300
def test_sample_run_config(self):
"""Test sample_run_config generates valid configuration."""
# Default environment variable
sample = MLflowPromptsConfig.sample_run_config()
assert sample["mlflow_tracking_uri"] == "http://localhost:5000"
assert sample["experiment_name"] == "llama-stack-prompts"
assert sample["auth_credential"] == "${env.MLFLOW_TRACKING_TOKEN:=}"
# Custom values
sample = MLflowPromptsConfig.sample_run_config(
mlflow_api_token="test-token",
mlflow_tracking_uri="http://custom:5000",
)
assert sample["mlflow_tracking_uri"] == "http://custom:5000"
assert sample["auth_credential"] == "test-token"
class TestMLflowProviderDataValidator:
"""Tests for MLflowProviderDataValidator."""
def test_provider_data_validator(self):
"""Test provider data validator with and without token."""
# With token
validator = MLflowProviderDataValidator(mlflow_api_token="test-token-123")
assert validator.mlflow_api_token == "test-token-123"
# Without token
validator = MLflowProviderDataValidator()
assert validator.mlflow_api_token is None
# From dictionary
data = {"mlflow_api_token": "secret-token"}
validator = MLflowProviderDataValidator(**data)
assert validator.mlflow_api_token == "secret-token"

View file

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Unit tests for MLflow prompts ID mapping utilities."""
import pytest
from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper
class TestPromptIDMapper:
"""Tests for PromptIDMapper class."""
@pytest.fixture
def mapper(self):
"""Create ID mapper instance."""
return PromptIDMapper()
def test_to_mlflow_name_valid(self, mapper):
"""Test converting valid prompt_id to MLflow name."""
prompt_id = "pmpt_" + "a" * 48
mlflow_name = mapper.to_mlflow_name(prompt_id)
assert mlflow_name == "llama_prompt_" + "a" * 48
assert mlflow_name.startswith(mapper.MLFLOW_NAME_PREFIX)
def test_to_mlflow_name_invalid(self, mapper):
"""Test conversion fails with invalid inputs."""
# Invalid prefix
with pytest.raises(ValueError, match="Invalid prompt_id format"):
mapper.to_mlflow_name("invalid_" + "a" * 48)
# Wrong length
with pytest.raises(ValueError, match="Invalid prompt_id format"):
mapper.to_mlflow_name("pmpt_" + "a" * 47)
# Invalid hex characters
with pytest.raises(ValueError, match="Invalid prompt_id format"):
mapper.to_mlflow_name("pmpt_" + "g" * 48)
def test_to_llama_id_valid(self, mapper):
"""Test converting valid MLflow name to prompt_id."""
mlflow_name = "llama_prompt_" + "b" * 48
prompt_id = mapper.to_llama_id(mlflow_name)
assert prompt_id == "pmpt_" + "b" * 48
assert prompt_id.startswith("pmpt_")
def test_to_llama_id_invalid(self, mapper):
"""Test conversion fails with invalid inputs."""
# Invalid prefix
with pytest.raises(ValueError, match="does not start with expected prefix"):
mapper.to_llama_id("wrong_prefix_" + "a" * 48)
# Wrong length
with pytest.raises(ValueError, match="Invalid hex part length"):
mapper.to_llama_id("llama_prompt_" + "a" * 47)
# Invalid hex characters
with pytest.raises(ValueError, match="Invalid character"):
mapper.to_llama_id("llama_prompt_" + "G" * 48)
def test_bidirectional_conversion(self, mapper):
"""Test bidirectional conversion preserves IDs."""
original_id = "pmpt_0123456789abcdef" + "a" * 32
# Convert to MLflow name and back
mlflow_name = mapper.to_mlflow_name(original_id)
recovered_id = mapper.to_llama_id(mlflow_name)
assert recovered_id == original_id
def test_get_metadata_tags_with_variables(self, mapper):
"""Test metadata tags generation with variables."""
prompt_id = "pmpt_" + "c" * 48
variables = ["var1", "var2", "var3"]
tags = mapper.get_metadata_tags(prompt_id, variables)
assert tags["llama_stack_id"] == prompt_id
assert tags["llama_stack_managed"] == "true"
assert tags["variables"] == "var1,var2,var3"
def test_get_metadata_tags_without_variables(self, mapper):
"""Test metadata tags generation without variables."""
prompt_id = "pmpt_" + "d" * 48
tags = mapper.get_metadata_tags(prompt_id)
assert tags["llama_stack_id"] == prompt_id
assert tags["llama_stack_managed"] == "true"
assert "variables" not in tags