mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
Merge 5c4da04f29 into 4237eb4aaa
This commit is contained in:
commit
d5836c3b5a
24 changed files with 3594 additions and 225 deletions
125
.github/workflows/integration-mlflow-tests.yml
vendored
Normal file
125
.github/workflows/integration-mlflow-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
name: MLflow Prompts Integration Tests
|
||||||
|
|
||||||
|
run-name: Run the integration test suite with MLflow Prompt Registry provider
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- 'release-[0-9]+.[0-9]+.x'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- 'release-[0-9]+.[0-9]+.x'
|
||||||
|
paths:
|
||||||
|
- 'src/llama_stack/providers/remote/prompts/mlflow/**'
|
||||||
|
- 'tests/integration/providers/remote/prompts/mlflow/**'
|
||||||
|
- 'tests/unit/providers/remote/prompts/mlflow/**'
|
||||||
|
- 'uv.lock'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '.github/workflows/integration-mlflow-tests.yml' # This workflow
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * *' # Daily at 12 AM UTC
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test-mlflow:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
uses: ./.github/actions/setup-runner
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Setup MLflow Server
|
||||||
|
run: |
|
||||||
|
docker run --rm -d --pull always \
|
||||||
|
--name mlflow \
|
||||||
|
-p 5555:5555 \
|
||||||
|
ghcr.io/mlflow/mlflow:latest \
|
||||||
|
mlflow server \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 5555 \
|
||||||
|
--backend-store-uri sqlite:///mlflow.db \
|
||||||
|
--default-artifact-root ./mlruns
|
||||||
|
|
||||||
|
- name: Wait for MLflow to be ready
|
||||||
|
run: |
|
||||||
|
echo "Waiting for MLflow to be ready..."
|
||||||
|
for i in {1..60}; do
|
||||||
|
if curl -s http://localhost:5555/health | grep -q '"status": "OK"'; then
|
||||||
|
echo "MLflow is ready!"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo "Not ready yet... ($i/60)"
|
||||||
|
sleep 2
|
||||||
|
done
|
||||||
|
echo "MLflow failed to start"
|
||||||
|
docker logs mlflow
|
||||||
|
exit 1
|
||||||
|
|
||||||
|
- name: Verify MLflow API
|
||||||
|
run: |
|
||||||
|
echo "Testing MLflow API..."
|
||||||
|
curl -X GET http://localhost:5555/api/2.0/mlflow/experiments/list
|
||||||
|
echo ""
|
||||||
|
echo "MLflow API is responding!"
|
||||||
|
|
||||||
|
- name: Build Llama Stack
|
||||||
|
run: |
|
||||||
|
uv run --no-sync llama stack list-deps ci-tests | xargs -L1 uv pip install
|
||||||
|
|
||||||
|
- name: Install MLflow Python client
|
||||||
|
run: |
|
||||||
|
uv pip install 'mlflow>=3.4.0'
|
||||||
|
|
||||||
|
- name: Check Storage and Memory Available Before Tests
|
||||||
|
if: ${{ always() }}
|
||||||
|
run: |
|
||||||
|
free -h
|
||||||
|
df -h
|
||||||
|
|
||||||
|
- name: Run MLflow Integration Tests
|
||||||
|
env:
|
||||||
|
MLFLOW_TRACKING_URI: http://localhost:5555
|
||||||
|
run: |
|
||||||
|
uv run --no-sync \
|
||||||
|
pytest -sv \
|
||||||
|
tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
|
||||||
|
- name: Check Storage and Memory Available After Tests
|
||||||
|
if: ${{ always() }}
|
||||||
|
run: |
|
||||||
|
free -h
|
||||||
|
df -h
|
||||||
|
|
||||||
|
- name: Write MLflow logs to file
|
||||||
|
if: ${{ always() }}
|
||||||
|
run: |
|
||||||
|
docker logs mlflow > mlflow.log 2>&1 || true
|
||||||
|
|
||||||
|
- name: Upload all logs to artifacts
|
||||||
|
if: ${{ always() }}
|
||||||
|
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
|
||||||
|
with:
|
||||||
|
name: mlflow-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
*.log
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
- name: Stop MLflow container
|
||||||
|
if: ${{ always() }}
|
||||||
|
run: |
|
||||||
|
docker stop mlflow || true
|
||||||
92
docs/docs/providers/prompts/index.mdx
Normal file
92
docs/docs/providers/prompts/index.mdx
Normal file
|
|
@ -0,0 +1,92 @@
|
||||||
|
---
|
||||||
|
sidebar_label: Prompts
|
||||||
|
title: Prompts
|
||||||
|
---
|
||||||
|
|
||||||
|
# Prompts
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This section contains documentation for all available providers for the **prompts** API.
|
||||||
|
|
||||||
|
The Prompts API enables centralized management of prompt templates with versioning, variable handling, and team collaboration capabilities.
|
||||||
|
|
||||||
|
## Available Providers
|
||||||
|
|
||||||
|
### Inline Providers
|
||||||
|
|
||||||
|
Inline providers run in the same process as the Llama Stack server and require no external dependencies:
|
||||||
|
|
||||||
|
- **[inline::reference](inline_reference.mdx)** - Reference implementation using KVStore backend (SQLite, PostgreSQL, etc.)
|
||||||
|
- Zero external dependencies
|
||||||
|
- Supports local SQLite or PostgreSQL storage
|
||||||
|
- Full CRUD operations including deletion
|
||||||
|
- Ideal for local development and single-server deployments
|
||||||
|
|
||||||
|
### Remote Providers
|
||||||
|
|
||||||
|
Remote providers connect to external services for centralized prompt management:
|
||||||
|
|
||||||
|
- **[remote::mlflow](remote_mlflow.mdx)** - MLflow Prompt Registry integration (requires MLflow 3.4+)
|
||||||
|
- Centralized prompt management across teams
|
||||||
|
- Built-in versioning and audit trail
|
||||||
|
- Supports authentication (per-request, config, or environment variables)
|
||||||
|
- Integrates with Databricks and enterprise MLflow deployments
|
||||||
|
- Ideal for team collaboration and production environments
|
||||||
|
|
||||||
|
## Choosing a Provider
|
||||||
|
|
||||||
|
### Use `inline::reference` when:
|
||||||
|
- Developing locally or deploying to a single server
|
||||||
|
- You want zero external dependencies
|
||||||
|
- SQLite or PostgreSQL storage is sufficient
|
||||||
|
- You need full CRUD operations (including deletion)
|
||||||
|
- You prefer simple configuration
|
||||||
|
|
||||||
|
### Use `remote::mlflow` when:
|
||||||
|
- Working in a team environment with multiple users
|
||||||
|
- You need centralized prompt management
|
||||||
|
- Integration with existing MLflow infrastructure
|
||||||
|
- You need authentication and multi-tenant support
|
||||||
|
- Advanced versioning and audit trail capabilities are required
|
||||||
|
|
||||||
|
## Quick Start Examples
|
||||||
|
|
||||||
|
### Using inline::reference
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: local-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./prompts.db
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using remote::mlflow
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://localhost:5555
|
||||||
|
experiment_name: llama-stack-prompts
|
||||||
|
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Features
|
||||||
|
|
||||||
|
All prompt providers support:
|
||||||
|
- Create and store prompts with version control
|
||||||
|
- Retrieve prompts by ID and version
|
||||||
|
- Update prompts (creates new versions)
|
||||||
|
- List all prompts or versions of a specific prompt
|
||||||
|
- Set default version for a prompt
|
||||||
|
- Automatic variable extraction from `{{ variable }}` templates
|
||||||
|
|
||||||
|
For detailed documentation on each provider, see the individual provider pages linked above.
|
||||||
496
docs/docs/providers/prompts/inline_reference.mdx
Normal file
496
docs/docs/providers/prompts/inline_reference.mdx
Normal file
|
|
@ -0,0 +1,496 @@
|
||||||
|
---
|
||||||
|
description: |
|
||||||
|
Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.)
|
||||||
|
for centralized prompt management with versioning support. This is the default provider for
|
||||||
|
prompts that works without external dependencies.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
The Reference Prompts Provider supports:
|
||||||
|
- Create and store prompts with automatic versioning
|
||||||
|
- Retrieve prompts by ID and version
|
||||||
|
- Update prompts (creates new immutable versions)
|
||||||
|
- Delete prompts and their versions
|
||||||
|
- List all prompts or all versions of a specific prompt
|
||||||
|
- Set default version for a prompt
|
||||||
|
- Automatic variable extraction from templates
|
||||||
|
- Storage in SQLite, PostgreSQL, or other KVStore backends
|
||||||
|
|
||||||
|
## Key Capabilities
|
||||||
|
- **Zero Dependencies**: No external services required, runs in-process
|
||||||
|
- **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends
|
||||||
|
- **Version Control**: Immutable versioning ensures prompt history is preserved
|
||||||
|
- **Default Version Management**: Easily switch between prompt versions
|
||||||
|
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
|
||||||
|
- **Full CRUD Support**: Unlike remote providers, supports deletion of prompts
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use Reference Prompts Provider in your Llama Stack project:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project with the inline::reference provider
|
||||||
|
2. Optionally configure storage backend (defaults to SQLite)
|
||||||
|
3. Start creating and managing prompts
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Configure Llama Stack
|
||||||
|
|
||||||
|
**Basic configuration with SQLite** (default):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: reference-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./prompts.db
|
||||||
|
```
|
||||||
|
|
||||||
|
**With PostgreSQL**:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: postgres-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: postgres
|
||||||
|
url: postgresql://user:pass@localhost/llama_stack
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Use the Prompts API
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
|
||||||
|
# Create a prompt
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}",
|
||||||
|
variables=["num_sentences", "text"]
|
||||||
|
)
|
||||||
|
print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})")
|
||||||
|
|
||||||
|
# Retrieve prompt
|
||||||
|
retrieved = client.prompts.get(prompt_id=prompt.prompt_id)
|
||||||
|
print(f"Retrieved: {retrieved.prompt}")
|
||||||
|
|
||||||
|
# Update prompt (creates version 2)
|
||||||
|
updated = client.prompts.update(
|
||||||
|
prompt_id=prompt.prompt_id,
|
||||||
|
prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}",
|
||||||
|
version=1,
|
||||||
|
set_as_default=True
|
||||||
|
)
|
||||||
|
print(f"Updated to version: {updated.version}")
|
||||||
|
|
||||||
|
# List all prompts
|
||||||
|
prompts = client.prompts.list()
|
||||||
|
print(f"Found {len(prompts.data)} prompts")
|
||||||
|
|
||||||
|
# Delete prompt
|
||||||
|
client.prompts.delete(prompt_id=prompt.prompt_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
sidebar_label: Inline - Reference
|
||||||
|
title: inline::reference
|
||||||
|
---
|
||||||
|
|
||||||
|
# inline::reference
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
Reference implementation of the Prompts API using KVStore backend (SQLite, PostgreSQL, etc.)
|
||||||
|
for centralized prompt management with versioning support. This is the default provider for
|
||||||
|
prompts that works without external dependencies.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
The Reference Prompts Provider supports:
|
||||||
|
- Create and store prompts with automatic versioning
|
||||||
|
- Retrieve prompts by ID and version
|
||||||
|
- Update prompts (creates new immutable versions)
|
||||||
|
- Delete prompts and their versions
|
||||||
|
- List all prompts or all versions of a specific prompt
|
||||||
|
- Set default version for a prompt
|
||||||
|
- Automatic variable extraction from templates
|
||||||
|
- Storage in SQLite, PostgreSQL, or other KVStore backends
|
||||||
|
|
||||||
|
## Key Capabilities
|
||||||
|
- **Zero Dependencies**: No external services required, runs in-process
|
||||||
|
- **Flexible Storage**: Supports SQLite (default), PostgreSQL, and other KVStore backends
|
||||||
|
- **Version Control**: Immutable versioning ensures prompt history is preserved
|
||||||
|
- **Default Version Management**: Easily switch between prompt versions
|
||||||
|
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
|
||||||
|
- **Full CRUD Support**: Unlike remote providers, supports deletion of prompts
|
||||||
|
|
||||||
|
## Configuration Examples
|
||||||
|
|
||||||
|
### SQLite (Local Development)
|
||||||
|
|
||||||
|
For local development with filesystem storage:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: local-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./prompts.db
|
||||||
|
```
|
||||||
|
|
||||||
|
### PostgreSQL (Production)
|
||||||
|
|
||||||
|
For production with PostgreSQL:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: prod-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: postgres
|
||||||
|
url: ${env.DATABASE_URL}
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Explicit Backend Configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: reference-prompts
|
||||||
|
provider_type: inline::reference
|
||||||
|
config:
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
backends:
|
||||||
|
kv_default:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./data/prompts.db
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
backend: kv_default
|
||||||
|
namespace: prompts
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Create Prompt
|
||||||
|
|
||||||
|
Creates a new prompt (version 1):
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="You are a {{ role }} assistant. {{ instruction }}",
|
||||||
|
variables=["role", "instruction"] # Optional - auto-extracted if omitted
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders.
|
||||||
|
|
||||||
|
### Retrieve Prompt
|
||||||
|
|
||||||
|
Get a prompt by ID (retrieves default version):
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.get(prompt_id="pmpt_abc123...")
|
||||||
|
```
|
||||||
|
|
||||||
|
Get a specific version:
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update Prompt
|
||||||
|
|
||||||
|
Creates a new version of an existing prompt:
|
||||||
|
|
||||||
|
```python
|
||||||
|
updated = client.prompts.update(
|
||||||
|
prompt_id="pmpt_abc123...",
|
||||||
|
prompt="Updated template with {{ variable }}",
|
||||||
|
version=1, # Must be the latest version
|
||||||
|
set_as_default=True # Make this the new default
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2).
|
||||||
|
|
||||||
|
### Delete Prompt
|
||||||
|
|
||||||
|
Delete a prompt and all its versions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client.prompts.delete(prompt_id="pmpt_abc123...")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: This operation is permanent and deletes all versions of the prompt.
|
||||||
|
|
||||||
|
### List Prompts
|
||||||
|
|
||||||
|
List all prompts (returns default versions only):
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = client.prompts.list()
|
||||||
|
for prompt in response.data:
|
||||||
|
print(f"{prompt.prompt_id}: v{prompt.version} (default)")
|
||||||
|
```
|
||||||
|
|
||||||
|
### List Prompt Versions
|
||||||
|
|
||||||
|
List all versions of a specific prompt:
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = client.prompts.list_versions(prompt_id="pmpt_abc123...")
|
||||||
|
for prompt in response.data:
|
||||||
|
default = " (default)" if prompt.is_default else ""
|
||||||
|
print(f"Version {prompt.version}{default}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Set Default Version
|
||||||
|
|
||||||
|
Change which version is the default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client.prompts.set_default_version(
|
||||||
|
prompt_id="pmpt_abc123...",
|
||||||
|
version=2
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Version Management
|
||||||
|
|
||||||
|
The Reference Prompts Provider implements immutable versioning:
|
||||||
|
|
||||||
|
1. **Create**: Creates version 1
|
||||||
|
2. **Update**: Creates a new version (2, 3, 4, ...)
|
||||||
|
3. **Default**: One version is marked as default
|
||||||
|
4. **History**: All versions are preserved and retrievable
|
||||||
|
5. **Delete**: Can delete all versions at once
|
||||||
|
|
||||||
|
```
|
||||||
|
pmpt_abc123
|
||||||
|
├── Version 1 (Original)
|
||||||
|
├── Version 2 (Updated)
|
||||||
|
└── Version 3 (Latest, Default) <- Current default version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Storage Backends
|
||||||
|
|
||||||
|
The reference provider uses Llama Stack's KVStore abstraction, which supports multiple backends:
|
||||||
|
|
||||||
|
### SQLite (Default)
|
||||||
|
|
||||||
|
Best for:
|
||||||
|
- Local development
|
||||||
|
- Single-server deployments
|
||||||
|
- Embedded applications
|
||||||
|
- Testing
|
||||||
|
|
||||||
|
Limitations:
|
||||||
|
- Not suitable for high-concurrency scenarios
|
||||||
|
- No built-in replication
|
||||||
|
|
||||||
|
### PostgreSQL
|
||||||
|
|
||||||
|
Best for:
|
||||||
|
- Production deployments
|
||||||
|
- Multi-server setups
|
||||||
|
- High availability requirements
|
||||||
|
- Team collaboration
|
||||||
|
|
||||||
|
Advantages:
|
||||||
|
- Supports concurrent access
|
||||||
|
- Built-in replication and backups
|
||||||
|
- Scalable and robust
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### 1. Choose Appropriate Storage
|
||||||
|
|
||||||
|
**Development**:
|
||||||
|
```yaml
|
||||||
|
# Use SQLite for local development
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./dev-prompts.db
|
||||||
|
```
|
||||||
|
|
||||||
|
**Production**:
|
||||||
|
```yaml
|
||||||
|
# Use PostgreSQL for production
|
||||||
|
storage:
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
type: postgres
|
||||||
|
url: ${env.DATABASE_URL}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Backup Your Data
|
||||||
|
|
||||||
|
For SQLite:
|
||||||
|
```bash
|
||||||
|
# Backup SQLite database
|
||||||
|
cp prompts.db prompts.db.backup
|
||||||
|
```
|
||||||
|
|
||||||
|
For PostgreSQL:
|
||||||
|
```bash
|
||||||
|
# Backup PostgreSQL database
|
||||||
|
pg_dump llama_stack > backup.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Version Management
|
||||||
|
|
||||||
|
- Always retrieve latest version before updating
|
||||||
|
- Use `set_as_default=True` when updating to make new version active
|
||||||
|
- Keep version history for audit trail
|
||||||
|
- Use deletion sparingly (consider archiving instead)
|
||||||
|
|
||||||
|
### 4. Auto-Extract Variables
|
||||||
|
|
||||||
|
Let the provider auto-extract variables to avoid validation errors:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Recommended
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Summarize {{ text }} in {{ format }}"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Use Meaningful Templates
|
||||||
|
|
||||||
|
Include context in your templates:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Good
|
||||||
|
prompt = """You are a {{ role }} assistant specialized in {{ domain }}.
|
||||||
|
|
||||||
|
Task: {{ task }}
|
||||||
|
|
||||||
|
Output format: {{ format }}"""
|
||||||
|
|
||||||
|
# Less clear
|
||||||
|
prompt = "Do {{ task }} as {{ role }}"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Database Connection Errors
|
||||||
|
|
||||||
|
**Error**: Failed to connect to database
|
||||||
|
|
||||||
|
**Solutions**:
|
||||||
|
1. Verify database URL is correct
|
||||||
|
2. Ensure database server is running (for PostgreSQL)
|
||||||
|
3. Check file permissions (for SQLite)
|
||||||
|
4. Verify network connectivity (for remote databases)
|
||||||
|
|
||||||
|
### Version Mismatch Error
|
||||||
|
|
||||||
|
**Error**: `Version X is not the latest version. Use latest version Y to update.`
|
||||||
|
|
||||||
|
**Cause**: Attempting to update an outdated version
|
||||||
|
|
||||||
|
**Solution**: Always use the latest version number when updating:
|
||||||
|
```python
|
||||||
|
# Get latest version
|
||||||
|
versions = client.prompts.list_versions(prompt_id)
|
||||||
|
latest_version = max(v.version for v in versions.data)
|
||||||
|
|
||||||
|
# Use latest version for update
|
||||||
|
client.prompts.update(prompt_id=prompt_id, version=latest_version, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Variable Validation Error
|
||||||
|
|
||||||
|
**Error**: `Template contains undeclared variables: ['var2']`
|
||||||
|
|
||||||
|
**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it
|
||||||
|
|
||||||
|
**Solution**: Either add missing variable or let the provider auto-extract:
|
||||||
|
```python
|
||||||
|
# Option 1: Add missing variable
|
||||||
|
client.prompts.create(
|
||||||
|
prompt="Template with {{ var1 }} and {{ var2 }}",
|
||||||
|
variables=["var1", "var2"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Option 2: Let provider auto-extract (recommended)
|
||||||
|
client.prompts.create(
|
||||||
|
prompt="Template with {{ var1 }} and {{ var2 }}"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prompt Not Found
|
||||||
|
|
||||||
|
**Error**: `Prompt pmpt_abc123... not found`
|
||||||
|
|
||||||
|
**Possible causes**:
|
||||||
|
1. Prompt ID is incorrect
|
||||||
|
2. Prompt was deleted
|
||||||
|
3. Wrong database or storage backend
|
||||||
|
|
||||||
|
**Solution**: Verify prompt exists using `list()` method
|
||||||
|
|
||||||
|
## Migration Guide
|
||||||
|
|
||||||
|
### Migrating from Core Implementation
|
||||||
|
|
||||||
|
If you're upgrading from an older Llama Stack version where prompts were in `core/prompts`:
|
||||||
|
|
||||||
|
**Old code** (still works):
|
||||||
|
```python
|
||||||
|
from llama_stack.core.prompts import PromptServiceConfig, PromptServiceImpl
|
||||||
|
```
|
||||||
|
|
||||||
|
**New code** (recommended):
|
||||||
|
```python
|
||||||
|
from llama_stack.providers.inline.prompts.reference import ReferencePromptsConfig, PromptServiceImpl
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: Backward compatibility is maintained. Old imports still work.
|
||||||
|
|
||||||
|
### Data Migration
|
||||||
|
|
||||||
|
No data migration needed when upgrading:
|
||||||
|
- Same KVStore backend is used
|
||||||
|
- Existing prompts remain accessible
|
||||||
|
- Configuration structure is compatible
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
| Field | Type | Required | Default | Description |
|
||||||
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `run_config` | `StackRunConfig` | Yes | | Stack run configuration containing storage configuration for KVStore |
|
||||||
|
|
||||||
|
## Sample Configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
run_config:
|
||||||
|
storage:
|
||||||
|
backends:
|
||||||
|
kv_default:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ./prompts.db
|
||||||
|
stores:
|
||||||
|
prompts:
|
||||||
|
backend: kv_default
|
||||||
|
namespace: prompts
|
||||||
|
```
|
||||||
751
docs/docs/providers/prompts/remote_mlflow.mdx
Normal file
751
docs/docs/providers/prompts/remote_mlflow.mdx
Normal file
|
|
@ -0,0 +1,751 @@
|
||||||
|
---
|
||||||
|
description: |
|
||||||
|
[MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning
|
||||||
|
using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage
|
||||||
|
prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management.
|
||||||
|
|
||||||
|
See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry.
|
||||||
|
|
||||||
|
sidebar_label: Remote - MLflow
|
||||||
|
title: remote::mlflow
|
||||||
|
---
|
||||||
|
|
||||||
|
# remote::mlflow
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
[MLflow](https://mlflow.org/) is a remote provider for centralized prompt management and versioning
|
||||||
|
using MLflow's Prompt Registry (available in MLflow 3.4+). It allows you to store, version, and manage
|
||||||
|
prompts in a centralized MLflow server, enabling team collaboration and prompt lifecycle management.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
MLflow Prompts Provider supports:
|
||||||
|
- Create and store prompts with automatic versioning
|
||||||
|
- Retrieve prompts by ID and version
|
||||||
|
- Update prompts (creates new immutable versions)
|
||||||
|
- List all prompts or all versions of a specific prompt
|
||||||
|
- Set default version for a prompt
|
||||||
|
- Automatic variable extraction from templates
|
||||||
|
- Metadata storage and retrieval
|
||||||
|
- Centralized prompt management across teams
|
||||||
|
|
||||||
|
## Key Capabilities
|
||||||
|
- **Version Control**: Immutable versioning ensures prompt history is preserved
|
||||||
|
- **Default Version Management**: Easily switch between prompt versions
|
||||||
|
- **Variable Auto-Extraction**: Automatically detects `{{ variable }}` placeholders
|
||||||
|
- **Metadata Tags**: Stores Llama Stack metadata for seamless integration
|
||||||
|
- **Team Collaboration**: Centralized MLflow server enables multi-user access
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use MLflow Prompts Provider in your Llama Stack project:
|
||||||
|
|
||||||
|
1. Install MLflow 3.4 or later
|
||||||
|
2. Start an MLflow server (local or remote)
|
||||||
|
3. Configure your Llama Stack project to use the MLflow provider
|
||||||
|
4. Start creating and managing prompts
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Install MLflow using pip or uv:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install 'mlflow>=3.4.0'
|
||||||
|
# or
|
||||||
|
uv pip install 'mlflow>=3.4.0'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Start MLflow Server
|
||||||
|
|
||||||
|
**Local server** (for development):
|
||||||
|
```bash
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
```
|
||||||
|
|
||||||
|
**Remote server** (for production):
|
||||||
|
```bash
|
||||||
|
mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri postgresql://user:pass@host/db
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Configure Llama Stack
|
||||||
|
|
||||||
|
Add to your Llama Stack configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://localhost:5555
|
||||||
|
experiment_name: llama-stack-prompts
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Use the Prompts API
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
|
||||||
|
# Create a prompt
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Summarize the following text in {{ num_sentences }} sentences:\n\n{{ text }}",
|
||||||
|
variables=["num_sentences", "text"]
|
||||||
|
)
|
||||||
|
print(f"Created prompt: {prompt.prompt_id} (v{prompt.version})")
|
||||||
|
|
||||||
|
# Retrieve prompt
|
||||||
|
retrieved = client.prompts.get(prompt_id=prompt.prompt_id)
|
||||||
|
print(f"Retrieved: {retrieved.prompt}")
|
||||||
|
|
||||||
|
# Update prompt (creates version 2)
|
||||||
|
updated = client.prompts.update(
|
||||||
|
prompt_id=prompt.prompt_id,
|
||||||
|
prompt="Summarize in exactly {{ num_sentences }} sentences:\n\n{{ text }}",
|
||||||
|
version=1,
|
||||||
|
set_as_default=True
|
||||||
|
)
|
||||||
|
print(f"Updated to version: {updated.version}")
|
||||||
|
|
||||||
|
# List all prompts
|
||||||
|
prompts = client.prompts.list()
|
||||||
|
print(f"Found {len(prompts.data)} prompts")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Examples
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
For local development with filesystem storage:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-local
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://localhost:5555
|
||||||
|
experiment_name: dev-prompts
|
||||||
|
timeout_seconds: 30
|
||||||
|
```
|
||||||
|
|
||||||
|
### Remote MLflow Server
|
||||||
|
|
||||||
|
For production with a remote MLflow server:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-production
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI}
|
||||||
|
experiment_name: production-prompts
|
||||||
|
timeout_seconds: 60
|
||||||
|
```
|
||||||
|
|
||||||
|
### Advanced Configuration
|
||||||
|
|
||||||
|
With custom settings:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-custom
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: https://mlflow.example.com
|
||||||
|
experiment_name: team-prompts
|
||||||
|
timeout_seconds: 45
|
||||||
|
```
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
The MLflow provider supports three authentication methods with the following precedence (highest to lowest):
|
||||||
|
|
||||||
|
1. **Per-Request Provider Data** (via headers)
|
||||||
|
2. **Configuration Auth Credential** (in config file)
|
||||||
|
3. **Environment Variables** (MLflow defaults)
|
||||||
|
|
||||||
|
### Method 1: Per-Request Provider Data (Recommended for Multi-Tenant)
|
||||||
|
|
||||||
|
For multi-tenant deployments where each user has their own credentials:
|
||||||
|
|
||||||
|
**Configuration**:
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://mlflow.company.com
|
||||||
|
experiment_name: production-prompts
|
||||||
|
# No auth_credential - use per-request tokens
|
||||||
|
```
|
||||||
|
|
||||||
|
**Client Usage**:
|
||||||
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
|
||||||
|
# User 1 with their own token
|
||||||
|
prompts_user1 = client.prompts.list(
|
||||||
|
extra_headers={
|
||||||
|
"x-llamastack-provider-data": '{"mlflow_api_token": "user1-token"}'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# User 2 with their own token
|
||||||
|
prompts_user2 = client.prompts.list(
|
||||||
|
extra_headers={
|
||||||
|
"x-llamastack-provider-data": '{"mlflow_api_token": "user2-token"}'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits**:
|
||||||
|
- Per-user authentication and authorization
|
||||||
|
- No shared credentials
|
||||||
|
- Ideal for SaaS deployments
|
||||||
|
- Supports user-specific MLflow experiments
|
||||||
|
|
||||||
|
### Method 2: Configuration Auth Credential (Server-Level)
|
||||||
|
|
||||||
|
For server-level authentication where all requests use the same credentials:
|
||||||
|
|
||||||
|
**Using Environment Variable** (recommended):
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://mlflow.company.com
|
||||||
|
experiment_name: production-prompts
|
||||||
|
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using Direct Value** (not recommended for production):
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://mlflow.company.com
|
||||||
|
experiment_name: production-prompts
|
||||||
|
auth_credential: "mlflow-server-token"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Client Usage**:
|
||||||
|
```python
|
||||||
|
# No extra headers needed - server handles authentication
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
prompts = client.prompts.list()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits**:
|
||||||
|
- Simple configuration
|
||||||
|
- Single point of credential management
|
||||||
|
- Good for single-tenant deployments
|
||||||
|
|
||||||
|
### Method 3: Environment Variables (MLflow Default)
|
||||||
|
|
||||||
|
MLflow reads standard environment variables automatically:
|
||||||
|
|
||||||
|
**Set before starting Llama Stack**:
|
||||||
|
```bash
|
||||||
|
export MLFLOW_TRACKING_TOKEN="your-token"
|
||||||
|
export MLFLOW_TRACKING_USERNAME="user" # Optional: Basic auth
|
||||||
|
export MLFLOW_TRACKING_PASSWORD="pass" # Optional: Basic auth
|
||||||
|
llama stack run my-config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
**Configuration** (no auth_credential needed):
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://mlflow.company.com
|
||||||
|
experiment_name: production-prompts
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits**:
|
||||||
|
- Standard MLflow behavior
|
||||||
|
- No configuration changes needed
|
||||||
|
- Good for containerized deployments
|
||||||
|
|
||||||
|
### Databricks Authentication
|
||||||
|
|
||||||
|
For Databricks-managed MLflow:
|
||||||
|
|
||||||
|
**Configuration**:
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: databricks-prompts
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: databricks
|
||||||
|
# Or with workspace URL:
|
||||||
|
# mlflow_tracking_uri: databricks://profile-name
|
||||||
|
experiment_name: /Shared/llama-stack-prompts
|
||||||
|
auth_credential: ${env.DATABRICKS_TOKEN}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Environment Setup**:
|
||||||
|
```bash
|
||||||
|
export DATABRICKS_TOKEN="dapi..."
|
||||||
|
export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Client Usage**:
|
||||||
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
|
||||||
|
# Create prompt in Databricks MLflow
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Analyze {{ topic }} with focus on {{ aspect }}",
|
||||||
|
variables=["topic", "aspect"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# View in Databricks UI:
|
||||||
|
# https://workspace.cloud.databricks.com/#mlflow/experiments/<experiment-id>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Enterprise MLflow with Authentication
|
||||||
|
|
||||||
|
Example for enterprise MLflow server with API key authentication:
|
||||||
|
|
||||||
|
**Configuration**:
|
||||||
|
```yaml
|
||||||
|
prompts:
|
||||||
|
- provider_id: enterprise-mlflow
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: https://mlflow.enterprise.com
|
||||||
|
experiment_name: production-prompts
|
||||||
|
auth_credential: ${env.MLFLOW_API_KEY}
|
||||||
|
timeout_seconds: 60
|
||||||
|
```
|
||||||
|
|
||||||
|
**Client Usage**:
|
||||||
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
# Option A: Use server's configured credential
|
||||||
|
client = LlamaStackClient(base_url="http://localhost:5000")
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Classify sentiment: {{ text }}",
|
||||||
|
variables=["text"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Option B: Override with per-request credential
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Classify sentiment: {{ text }}",
|
||||||
|
variables=["text"],
|
||||||
|
extra_headers={
|
||||||
|
"x-llamastack-provider-data": '{"mlflow_api_token": "user-specific-key"}'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Authentication Precedence
|
||||||
|
|
||||||
|
When multiple authentication methods are configured, the provider uses this precedence:
|
||||||
|
|
||||||
|
1. **Per-request provider data** (from `x-llamastack-provider-data` header)
|
||||||
|
- Highest priority
|
||||||
|
- Overrides all other methods
|
||||||
|
- Used for multi-tenant scenarios
|
||||||
|
|
||||||
|
2. **Configuration auth_credential** (from config file)
|
||||||
|
- Medium priority
|
||||||
|
- Fallback if no provider data header
|
||||||
|
- Good for server-level auth
|
||||||
|
|
||||||
|
3. **Environment variables** (MLflow standard)
|
||||||
|
- Lowest priority
|
||||||
|
- Used if no other credentials provided
|
||||||
|
- Standard MLflow behavior
|
||||||
|
|
||||||
|
**Example showing precedence**:
|
||||||
|
```yaml
|
||||||
|
# Config file
|
||||||
|
prompts:
|
||||||
|
- provider_id: mlflow
|
||||||
|
provider_type: remote::mlflow
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: http://mlflow.company.com
|
||||||
|
auth_credential: ${env.MLFLOW_TRACKING_TOKEN} # Fallback
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Environment variable
|
||||||
|
export MLFLOW_TRACKING_TOKEN="server-token" # Lowest priority
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Client code
|
||||||
|
client.prompts.create(
|
||||||
|
prompt="Test",
|
||||||
|
extra_headers={
|
||||||
|
# This takes precedence over config and env vars
|
||||||
|
"x-llamastack-provider-data": '{"mlflow_api_token": "user-token"}'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Security Best Practices
|
||||||
|
|
||||||
|
1. **Never hardcode tokens** in configuration files:
|
||||||
|
```yaml
|
||||||
|
# Bad - hardcoded credential
|
||||||
|
auth_credential: "my-secret-token"
|
||||||
|
|
||||||
|
# Good - use environment variable
|
||||||
|
auth_credential: ${env.MLFLOW_TRACKING_TOKEN}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Use per-request credentials** for multi-tenant deployments:
|
||||||
|
```python
|
||||||
|
# Good - each user provides their own token
|
||||||
|
headers = {
|
||||||
|
"x-llamastack-provider-data": f'{{"mlflow_api_token": "{user_token}"}}'
|
||||||
|
}
|
||||||
|
client.prompts.list(extra_headers=headers)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Rotate credentials regularly** in production environments
|
||||||
|
|
||||||
|
4. **Use HTTPS** for MLflow tracking URI in production:
|
||||||
|
```yaml
|
||||||
|
mlflow_tracking_uri: https://mlflow.company.com # Good
|
||||||
|
# Not: http://mlflow.company.com # Bad for production
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Store secrets in secure vaults** (AWS Secrets Manager, HashiCorp Vault, etc.)
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Create Prompt
|
||||||
|
|
||||||
|
Creates a new prompt (version 1) or registers a prompt in MLflow:
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="You are a {{ role }} assistant. {{ instruction }}",
|
||||||
|
variables=["role", "instruction"] # Optional - auto-extracted if omitted
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Auto-extraction**: If `variables` is not provided, the provider automatically extracts variables from `{{ variable }}` placeholders.
|
||||||
|
|
||||||
|
### Retrieve Prompt
|
||||||
|
|
||||||
|
Get a prompt by ID (retrieves default version):
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.get(prompt_id="pmpt_abc123...")
|
||||||
|
```
|
||||||
|
|
||||||
|
Get a specific version:
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt = client.prompts.get(prompt_id="pmpt_abc123...", version=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update Prompt
|
||||||
|
|
||||||
|
Creates a new version of an existing prompt:
|
||||||
|
|
||||||
|
```python
|
||||||
|
updated = client.prompts.update(
|
||||||
|
prompt_id="pmpt_abc123...",
|
||||||
|
prompt="Updated template with {{ variable }}",
|
||||||
|
version=1, # Must be the latest version
|
||||||
|
set_as_default=True # Make this the new default
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important**: You must provide the current latest version number. The update creates a new version (e.g., version 2).
|
||||||
|
|
||||||
|
### List Prompts
|
||||||
|
|
||||||
|
List all prompts (returns default versions only):
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = client.prompts.list()
|
||||||
|
for prompt in response.data:
|
||||||
|
print(f"{prompt.prompt_id}: v{prompt.version} (default)")
|
||||||
|
```
|
||||||
|
|
||||||
|
### List Prompt Versions
|
||||||
|
|
||||||
|
List all versions of a specific prompt:
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = client.prompts.list_versions(prompt_id="pmpt_abc123...")
|
||||||
|
for prompt in response.data:
|
||||||
|
default = " (default)" if prompt.is_default else ""
|
||||||
|
print(f"Version {prompt.version}{default}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Set Default Version
|
||||||
|
|
||||||
|
Change which version is the default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
client.prompts.set_default_version(
|
||||||
|
prompt_id="pmpt_abc123...",
|
||||||
|
version=2
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ID Mapping
|
||||||
|
|
||||||
|
The MLflow provider uses deterministic bidirectional ID mapping:
|
||||||
|
|
||||||
|
- **Llama Stack format**: `pmpt_<48-hex-chars>`
|
||||||
|
- **MLflow format**: `llama_prompt_<48-hex-chars>`
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- Llama Stack ID: `pmpt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c`
|
||||||
|
- MLflow name: `llama_prompt_8c2bf57972a215cd0413e399d03b901cce93815448173c1c`
|
||||||
|
|
||||||
|
This ensures prompts created through Llama Stack are easily identifiable in MLflow.
|
||||||
|
|
||||||
|
## Version Management
|
||||||
|
|
||||||
|
MLflow Prompts Provider implements immutable versioning:
|
||||||
|
|
||||||
|
1. **Create**: Creates version 1
|
||||||
|
2. **Update**: Creates a new version (2, 3, 4, ...)
|
||||||
|
3. **Default**: The "default" alias points to the current default version
|
||||||
|
4. **History**: All versions are preserved and retrievable
|
||||||
|
|
||||||
|
```
|
||||||
|
pmpt_abc123
|
||||||
|
├── Version 1 (Original)
|
||||||
|
├── Version 2 (Updated)
|
||||||
|
└── Version 3 (Latest, Default) ← Default alias points here
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### MLflow Server Not Available
|
||||||
|
|
||||||
|
**Error**: `Failed to connect to MLflow server`
|
||||||
|
|
||||||
|
**Solutions**:
|
||||||
|
1. Verify MLflow server is running: `curl http://localhost:5555/health`
|
||||||
|
2. Check `mlflow_tracking_uri` in configuration
|
||||||
|
3. Ensure network connectivity to remote server
|
||||||
|
4. Check firewall settings
|
||||||
|
|
||||||
|
### Version Mismatch Error
|
||||||
|
|
||||||
|
**Error**: `Version X is not the latest version. Use latest version Y to update.`
|
||||||
|
|
||||||
|
**Cause**: Attempting to update an outdated version
|
||||||
|
|
||||||
|
**Solution**: Always use the latest version number when updating:
|
||||||
|
```python
|
||||||
|
# Get latest version
|
||||||
|
versions = client.prompts.list_versions(prompt_id)
|
||||||
|
latest_version = max(v.version for v in versions.data)
|
||||||
|
|
||||||
|
# Use latest version for update
|
||||||
|
client.prompts.update(prompt_id=prompt_id, version=latest_version, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Variable Validation Error
|
||||||
|
|
||||||
|
**Error**: `Template contains undeclared variables: ['var2']`
|
||||||
|
|
||||||
|
**Cause**: Template has `{{ var2 }}` but `variables` list doesn't include it
|
||||||
|
|
||||||
|
**Solution**: Either add missing variable or let the provider auto-extract:
|
||||||
|
```python
|
||||||
|
# Option 1: Add missing variable
|
||||||
|
client.prompts.create(
|
||||||
|
prompt="Template with {{ var1 }} and {{ var2 }}",
|
||||||
|
variables=["var1", "var2"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Option 2: Let provider auto-extract (recommended)
|
||||||
|
client.prompts.create(
|
||||||
|
prompt="Template with {{ var1 }} and {{ var2 }}"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Timeout Errors
|
||||||
|
|
||||||
|
**Error**: Connection timeout when communicating with MLflow
|
||||||
|
|
||||||
|
**Solutions**:
|
||||||
|
1. Increase `timeout_seconds` in configuration:
|
||||||
|
```yaml
|
||||||
|
config:
|
||||||
|
timeout_seconds: 60 # Default: 30
|
||||||
|
```
|
||||||
|
2. Check network latency to MLflow server
|
||||||
|
3. Verify MLflow server is responsive
|
||||||
|
|
||||||
|
### Prompt Not Found
|
||||||
|
|
||||||
|
**Error**: `Prompt pmpt_abc123... not found`
|
||||||
|
|
||||||
|
**Possible causes**:
|
||||||
|
1. Prompt ID is incorrect
|
||||||
|
2. Prompt was created in a different MLflow server/experiment
|
||||||
|
3. Experiment name mismatch in configuration
|
||||||
|
|
||||||
|
**Solution**: Verify prompt exists in MLflow UI at `http://localhost:5555`
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
### No Deletion Support
|
||||||
|
|
||||||
|
**MLflow does not support deleting prompts or versions**. The `delete_prompt()` method raises `NotImplementedError`.
|
||||||
|
|
||||||
|
**Workaround**: Mark prompts as deprecated using naming conventions or set a different version as default.
|
||||||
|
|
||||||
|
### Experiment Required
|
||||||
|
|
||||||
|
All prompts are stored within an MLflow experiment. The experiment is created automatically if it doesn't exist.
|
||||||
|
|
||||||
|
### ID Format Constraints
|
||||||
|
|
||||||
|
- Prompt IDs must follow the format: `pmpt_<48-hex-chars>`
|
||||||
|
- MLflow names use the prefix: `llama_prompt_`
|
||||||
|
- Manual creation in MLflow with different names won't be recognized
|
||||||
|
|
||||||
|
### Version Numbering
|
||||||
|
|
||||||
|
- Versions are sequential integers (1, 2, 3, ...)
|
||||||
|
- You cannot skip version numbers
|
||||||
|
- You cannot manually set version numbers
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### 1. Use Environment Variables
|
||||||
|
|
||||||
|
Store MLflow URIs in environment variables:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
config:
|
||||||
|
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5555}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Auto-Extract Variables
|
||||||
|
|
||||||
|
Let the provider auto-extract variables to avoid validation errors:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Recommended
|
||||||
|
prompt = client.prompts.create(
|
||||||
|
prompt="Summarize {{ text }} in {{ format }}"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Organize by Experiment
|
||||||
|
|
||||||
|
Use different experiments for different environments:
|
||||||
|
|
||||||
|
- `dev-prompts` for development
|
||||||
|
- `staging-prompts` for staging
|
||||||
|
- `production-prompts` for production
|
||||||
|
|
||||||
|
### 4. Version Management
|
||||||
|
|
||||||
|
- Always retrieve latest version before updating
|
||||||
|
- Use `set_as_default=True` when updating to make new version active
|
||||||
|
- Keep version history for audit trail
|
||||||
|
|
||||||
|
### 5. Use Meaningful Templates
|
||||||
|
|
||||||
|
Include context in your templates:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Good
|
||||||
|
prompt = """You are a {{ role }} assistant specialized in {{ domain }}.
|
||||||
|
|
||||||
|
Task: {{ task }}
|
||||||
|
|
||||||
|
Output format: {{ format }}"""
|
||||||
|
|
||||||
|
# Less clear
|
||||||
|
prompt = "Do {{ task }} as {{ role }}"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Monitor MLflow Server
|
||||||
|
|
||||||
|
- Use MLflow UI to visualize prompts: `http://your-server:5555`
|
||||||
|
- Monitor experiment metrics and prompt versions
|
||||||
|
- Set up alerts for MLflow server health
|
||||||
|
|
||||||
|
## Production Deployment
|
||||||
|
|
||||||
|
### Database Backend
|
||||||
|
|
||||||
|
For production, use a database backend instead of filesystem:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mlflow server \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 5000 \
|
||||||
|
--backend-store-uri postgresql://user:pass@host:5432/mlflow \
|
||||||
|
--default-artifact-root s3://my-bucket/mlflow-artifacts
|
||||||
|
```
|
||||||
|
|
||||||
|
### High Availability
|
||||||
|
|
||||||
|
- Deploy multiple MLflow server instances behind a load balancer
|
||||||
|
- Use managed database (RDS, Cloud SQL, etc.)
|
||||||
|
- Store artifacts in object storage (S3, GCS, Azure Blob)
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
- Enable authentication on MLflow server
|
||||||
|
- Use HTTPS for MLflow tracking URI
|
||||||
|
- Restrict network access with firewall rules
|
||||||
|
- Use IAM roles for cloud deployments
|
||||||
|
|
||||||
|
### Monitoring
|
||||||
|
|
||||||
|
Set up monitoring for:
|
||||||
|
- MLflow server availability
|
||||||
|
- Database connection pool
|
||||||
|
- API response times
|
||||||
|
- Prompt creation/retrieval rates
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
See [MLflow's documentation](https://mlflow.org/docs/latest/prompts.html) for more details about MLflow Prompt Registry.
|
||||||
|
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
| Field | Type | Required | Default | Description |
|
||||||
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `mlflow_tracking_uri` | `str` | No | http://localhost:5000 | MLflow tracking server URI |
|
||||||
|
| `mlflow_registry_uri` | `str \| None` | No | None | MLflow model registry URI (defaults to tracking URI if not set) |
|
||||||
|
| `experiment_name` | `str` | No | llama-stack-prompts | MLflow experiment name for storing prompts |
|
||||||
|
| `auth_credential` | `SecretStr \| None` | No | None | MLflow API token for authentication. Can be overridden via provider data header. |
|
||||||
|
| `timeout_seconds` | `int` | No | 30 | Timeout for MLflow API calls (1-300 seconds) |
|
||||||
|
|
||||||
|
## Sample Configuration
|
||||||
|
|
||||||
|
**Without authentication** (local development):
|
||||||
|
```yaml
|
||||||
|
mlflow_tracking_uri: http://localhost:5555
|
||||||
|
experiment_name: llama-stack-prompts
|
||||||
|
timeout_seconds: 30
|
||||||
|
```
|
||||||
|
|
||||||
|
**With authentication** (production):
|
||||||
|
```yaml
|
||||||
|
mlflow_tracking_uri: ${env.MLFLOW_TRACKING_URI:=http://localhost:5000}
|
||||||
|
experiment_name: llama-stack-prompts
|
||||||
|
auth_credential: ${env.MLFLOW_TRACKING_TOKEN:=}
|
||||||
|
timeout_seconds: 30
|
||||||
|
```
|
||||||
|
|
@ -4,232 +4,26 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
"""Core prompts service delegating to inline::reference provider.
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
This module provides backward compatibility by delegating to the
|
||||||
|
inline::reference provider implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
from llama_stack.providers.inline.prompts.reference import (
|
||||||
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
PromptServiceImpl,
|
||||||
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
ReferencePromptsConfig,
|
||||||
|
get_adapter_impl,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-export for backward compatibility
|
||||||
|
PromptServiceConfig = ReferencePromptsConfig
|
||||||
|
get_provider_impl = get_adapter_impl
|
||||||
|
|
||||||
class PromptServiceConfig(BaseModel):
|
__all__ = [
|
||||||
"""Configuration for the built-in prompt service.
|
"PromptServiceImpl",
|
||||||
|
"PromptServiceConfig",
|
||||||
:param run_config: Stack run configuration containing distribution info
|
"ReferencePromptsConfig",
|
||||||
"""
|
"get_provider_impl",
|
||||||
|
"get_adapter_impl",
|
||||||
run_config: StackRunConfig
|
]
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
|
|
||||||
"""Get the prompt service implementation."""
|
|
||||||
impl = PromptServiceImpl(config, deps)
|
|
||||||
await impl.initialize()
|
|
||||||
return impl
|
|
||||||
|
|
||||||
|
|
||||||
class PromptServiceImpl(Prompts):
|
|
||||||
"""Built-in prompt service implementation using KVStore."""
|
|
||||||
|
|
||||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
|
||||||
self.config = config
|
|
||||||
self.deps = deps
|
|
||||||
self.kvstore: KVStore
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
# Use prompts store reference from run config
|
|
||||||
prompts_ref = self.config.run_config.storage.stores.prompts
|
|
||||||
if not prompts_ref:
|
|
||||||
raise ValueError("storage.stores.prompts must be configured in run config")
|
|
||||||
self.kvstore = await kvstore_impl(prompts_ref)
|
|
||||||
|
|
||||||
def _get_default_key(self, prompt_id: str) -> str:
|
|
||||||
"""Get the KVStore key that stores the default version number."""
|
|
||||||
return f"prompts:v1:{prompt_id}:default"
|
|
||||||
|
|
||||||
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
|
||||||
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
|
||||||
if version:
|
|
||||||
return self._get_version_key(prompt_id, str(version))
|
|
||||||
|
|
||||||
default_key = self._get_default_key(prompt_id)
|
|
||||||
resolved_version = await self.kvstore.get(default_key)
|
|
||||||
if resolved_version is None:
|
|
||||||
raise ValueError(f"Prompt {prompt_id}:default not found")
|
|
||||||
return self._get_version_key(prompt_id, resolved_version)
|
|
||||||
|
|
||||||
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
|
||||||
"""Get the KVStore key for a specific prompt version."""
|
|
||||||
return f"prompts:v1:{prompt_id}:{version}"
|
|
||||||
|
|
||||||
def _get_list_key_prefix(self) -> str:
|
|
||||||
"""Get the key prefix for listing prompts."""
|
|
||||||
return "prompts:v1:"
|
|
||||||
|
|
||||||
def _serialize_prompt(self, prompt: Prompt) -> str:
|
|
||||||
"""Serialize a prompt to JSON string for storage."""
|
|
||||||
return json.dumps(
|
|
||||||
{
|
|
||||||
"prompt_id": prompt.prompt_id,
|
|
||||||
"prompt": prompt.prompt,
|
|
||||||
"version": prompt.version,
|
|
||||||
"variables": prompt.variables or [],
|
|
||||||
"is_default": prompt.is_default,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _deserialize_prompt(self, data: str) -> Prompt:
|
|
||||||
"""Deserialize a prompt from JSON string."""
|
|
||||||
obj = json.loads(data)
|
|
||||||
return Prompt(
|
|
||||||
prompt_id=obj["prompt_id"],
|
|
||||||
prompt=obj["prompt"],
|
|
||||||
version=obj["version"],
|
|
||||||
variables=obj.get("variables", []),
|
|
||||||
is_default=obj.get("is_default", False),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def list_prompts(self) -> ListPromptsResponse:
|
|
||||||
"""List all prompts (default versions only)."""
|
|
||||||
prefix = self._get_list_key_prefix()
|
|
||||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
|
||||||
|
|
||||||
prompts = []
|
|
||||||
for key in keys:
|
|
||||||
if key.endswith(":default"):
|
|
||||||
try:
|
|
||||||
default_version = await self.kvstore.get(key)
|
|
||||||
if default_version:
|
|
||||||
prompt_id = key.replace(prefix, "").replace(":default", "")
|
|
||||||
version_key = self._get_version_key(prompt_id, default_version)
|
|
||||||
data = await self.kvstore.get(version_key)
|
|
||||||
if data:
|
|
||||||
prompt = self._deserialize_prompt(data)
|
|
||||||
prompts.append(prompt)
|
|
||||||
except (json.JSONDecodeError, KeyError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
|
||||||
return ListPromptsResponse(data=prompts)
|
|
||||||
|
|
||||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
|
||||||
"""Get a prompt by its identifier and optional version."""
|
|
||||||
key = await self._get_prompt_key(prompt_id, version)
|
|
||||||
data = await self.kvstore.get(key)
|
|
||||||
if data is None:
|
|
||||||
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
|
||||||
return self._deserialize_prompt(data)
|
|
||||||
|
|
||||||
async def create_prompt(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
variables: list[str] | None = None,
|
|
||||||
) -> Prompt:
|
|
||||||
"""Create a new prompt."""
|
|
||||||
if variables is None:
|
|
||||||
variables = []
|
|
||||||
|
|
||||||
prompt_obj = Prompt(
|
|
||||||
prompt_id=Prompt.generate_prompt_id(),
|
|
||||||
prompt=prompt,
|
|
||||||
version=1,
|
|
||||||
variables=variables,
|
|
||||||
)
|
|
||||||
|
|
||||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
|
||||||
data = self._serialize_prompt(prompt_obj)
|
|
||||||
await self.kvstore.set(version_key, data)
|
|
||||||
|
|
||||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
|
||||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
|
||||||
|
|
||||||
return prompt_obj
|
|
||||||
|
|
||||||
async def update_prompt(
|
|
||||||
self,
|
|
||||||
prompt_id: str,
|
|
||||||
prompt: str,
|
|
||||||
version: int,
|
|
||||||
variables: list[str] | None = None,
|
|
||||||
set_as_default: bool = True,
|
|
||||||
) -> Prompt:
|
|
||||||
"""Update an existing prompt (increments version)."""
|
|
||||||
if version < 1:
|
|
||||||
raise ValueError("Version must be >= 1")
|
|
||||||
if variables is None:
|
|
||||||
variables = []
|
|
||||||
|
|
||||||
prompt_versions = await self.list_prompt_versions(prompt_id)
|
|
||||||
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
|
||||||
|
|
||||||
if version and latest_prompt.version != version:
|
|
||||||
raise ValueError(
|
|
||||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
|
||||||
)
|
|
||||||
|
|
||||||
current_version = latest_prompt.version if version is None else version
|
|
||||||
new_version = current_version + 1
|
|
||||||
|
|
||||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
|
||||||
|
|
||||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
|
||||||
data = self._serialize_prompt(updated_prompt)
|
|
||||||
await self.kvstore.set(version_key, data)
|
|
||||||
|
|
||||||
if set_as_default:
|
|
||||||
await self.set_default_version(prompt_id, new_version)
|
|
||||||
|
|
||||||
return updated_prompt
|
|
||||||
|
|
||||||
async def delete_prompt(self, prompt_id: str) -> None:
|
|
||||||
"""Delete a prompt and all its versions."""
|
|
||||||
await self.get_prompt(prompt_id)
|
|
||||||
|
|
||||||
prefix = f"prompts:v1:{prompt_id}:"
|
|
||||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
await self.kvstore.delete(key)
|
|
||||||
|
|
||||||
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
|
||||||
"""List all versions of a specific prompt."""
|
|
||||||
prefix = f"prompts:v1:{prompt_id}:"
|
|
||||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
|
||||||
|
|
||||||
default_version = None
|
|
||||||
prompts = []
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
data = await self.kvstore.get(key)
|
|
||||||
if key.endswith(":default"):
|
|
||||||
default_version = data
|
|
||||||
else:
|
|
||||||
if data:
|
|
||||||
prompt_obj = self._deserialize_prompt(data)
|
|
||||||
prompts.append(prompt_obj)
|
|
||||||
|
|
||||||
if not prompts:
|
|
||||||
raise ValueError(f"Prompt {prompt_id} not found")
|
|
||||||
|
|
||||||
for prompt in prompts:
|
|
||||||
prompt.is_default = str(prompt.version) == default_version
|
|
||||||
|
|
||||||
prompts.sort(key=lambda x: x.version)
|
|
||||||
return ListPromptsResponse(data=prompts)
|
|
||||||
|
|
||||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
|
||||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
|
||||||
version_key = self._get_version_key(prompt_id, str(version))
|
|
||||||
data = await self.kvstore.get(version_key)
|
|
||||||
if data is None:
|
|
||||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
|
||||||
|
|
||||||
default_key = self._get_default_key(prompt_id)
|
|
||||||
await self.kvstore.set(default_key, str(version))
|
|
||||||
|
|
||||||
return self._deserialize_prompt(data)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
5
src/llama_stack/providers/inline/prompts/__init__.py
Normal file
5
src/llama_stack/providers/inline/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import ReferencePromptsConfig
|
||||||
|
from .reference import PromptServiceImpl
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: ReferencePromptsConfig, _deps):
|
||||||
|
impl = PromptServiceImpl(config=config, deps=_deps)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ReferencePromptsConfig", "PromptServiceImpl", "get_adapter_impl"]
|
||||||
21
src/llama_stack/providers/inline/prompts/reference/config.py
Normal file
21
src/llama_stack/providers/inline/prompts/reference/config.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ReferencePromptsConfig(BaseModel):
|
||||||
|
"""Configuration for the built-in reference prompt service.
|
||||||
|
|
||||||
|
This provider stores prompts in the configured KVStore (SQLite, PostgreSQL, etc.)
|
||||||
|
as specified in the run configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
run_config: StackRunConfig = Field(
|
||||||
|
description="Stack run configuration containing storage configuration"
|
||||||
|
)
|
||||||
222
src/llama_stack/providers/inline/prompts/reference/reference.py
Normal file
222
src/llama_stack/providers/inline/prompts/reference/reference.py
Normal file
|
|
@ -0,0 +1,222 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
||||||
|
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
||||||
|
|
||||||
|
from .config import ReferencePromptsConfig
|
||||||
|
|
||||||
|
|
||||||
|
class PromptServiceImpl(Prompts):
|
||||||
|
"""Reference inline prompt service implementation using KVStore.
|
||||||
|
|
||||||
|
This provider stores prompts in the configured KVStore backend (SQLite, PostgreSQL, etc.)
|
||||||
|
and provides full CRUD operations with versioning support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ReferencePromptsConfig, deps: dict[Any, Any]):
|
||||||
|
self.config = config
|
||||||
|
self.deps = deps
|
||||||
|
self.kvstore: KVStore
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
# Use prompts store reference from run config
|
||||||
|
prompts_ref = self.config.run_config.storage.stores.prompts
|
||||||
|
if not prompts_ref:
|
||||||
|
raise ValueError("storage.stores.prompts must be configured in run config")
|
||||||
|
self.kvstore = await kvstore_impl(prompts_ref)
|
||||||
|
|
||||||
|
def _get_default_key(self, prompt_id: str) -> str:
|
||||||
|
"""Get the KVStore key that stores the default version number."""
|
||||||
|
return f"prompts:v1:{prompt_id}:default"
|
||||||
|
|
||||||
|
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
||||||
|
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||||
|
if version:
|
||||||
|
return self._get_version_key(prompt_id, str(version))
|
||||||
|
|
||||||
|
default_key = self._get_default_key(prompt_id)
|
||||||
|
resolved_version = await self.kvstore.get(default_key)
|
||||||
|
if resolved_version is None:
|
||||||
|
raise ValueError(f"Prompt {prompt_id}:default not found")
|
||||||
|
return self._get_version_key(prompt_id, resolved_version)
|
||||||
|
|
||||||
|
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
||||||
|
"""Get the KVStore key for a specific prompt version."""
|
||||||
|
return f"prompts:v1:{prompt_id}:{version}"
|
||||||
|
|
||||||
|
def _get_list_key_prefix(self) -> str:
|
||||||
|
"""Get the key prefix for listing prompts."""
|
||||||
|
return "prompts:v1:"
|
||||||
|
|
||||||
|
def _serialize_prompt(self, prompt: Prompt) -> str:
|
||||||
|
"""Serialize a prompt to JSON string for storage."""
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"prompt_id": prompt.prompt_id,
|
||||||
|
"prompt": prompt.prompt,
|
||||||
|
"version": prompt.version,
|
||||||
|
"variables": prompt.variables or [],
|
||||||
|
"is_default": prompt.is_default,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _deserialize_prompt(self, data: str) -> Prompt:
|
||||||
|
"""Deserialize a prompt from JSON string."""
|
||||||
|
obj = json.loads(data)
|
||||||
|
return Prompt(
|
||||||
|
prompt_id=obj["prompt_id"],
|
||||||
|
prompt=obj["prompt"],
|
||||||
|
version=obj["version"],
|
||||||
|
variables=obj.get("variables", []),
|
||||||
|
is_default=obj.get("is_default", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_prompts(self) -> ListPromptsResponse:
|
||||||
|
"""List all prompts (default versions only)."""
|
||||||
|
prefix = self._get_list_key_prefix()
|
||||||
|
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||||
|
|
||||||
|
prompts = []
|
||||||
|
for key in keys:
|
||||||
|
if key.endswith(":default"):
|
||||||
|
try:
|
||||||
|
default_version = await self.kvstore.get(key)
|
||||||
|
if default_version:
|
||||||
|
prompt_id = key.replace(prefix, "").replace(":default", "")
|
||||||
|
version_key = self._get_version_key(prompt_id, default_version)
|
||||||
|
data = await self.kvstore.get(version_key)
|
||||||
|
if data:
|
||||||
|
prompt = self._deserialize_prompt(data)
|
||||||
|
prompts.append(prompt)
|
||||||
|
except (json.JSONDecodeError, KeyError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||||
|
return ListPromptsResponse(data=prompts)
|
||||||
|
|
||||||
|
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||||
|
"""Get a prompt by its identifier and optional version."""
|
||||||
|
key = await self._get_prompt_key(prompt_id, version)
|
||||||
|
data = await self.kvstore.get(key)
|
||||||
|
if data is None:
|
||||||
|
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||||
|
return self._deserialize_prompt(data)
|
||||||
|
|
||||||
|
async def create_prompt(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
variables: list[str] | None = None,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Create a new prompt."""
|
||||||
|
if variables is None:
|
||||||
|
variables = []
|
||||||
|
|
||||||
|
prompt_obj = Prompt(
|
||||||
|
prompt_id=Prompt.generate_prompt_id(),
|
||||||
|
prompt=prompt,
|
||||||
|
version=1,
|
||||||
|
variables=variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||||
|
data = self._serialize_prompt(prompt_obj)
|
||||||
|
await self.kvstore.set(version_key, data)
|
||||||
|
|
||||||
|
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||||
|
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||||
|
|
||||||
|
return prompt_obj
|
||||||
|
|
||||||
|
async def update_prompt(
|
||||||
|
self,
|
||||||
|
prompt_id: str,
|
||||||
|
prompt: str,
|
||||||
|
version: int,
|
||||||
|
variables: list[str] | None = None,
|
||||||
|
set_as_default: bool = True,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Update an existing prompt (increments version)."""
|
||||||
|
if version < 1:
|
||||||
|
raise ValueError("Version must be >= 1")
|
||||||
|
if variables is None:
|
||||||
|
variables = []
|
||||||
|
|
||||||
|
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||||
|
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||||
|
|
||||||
|
if version and latest_prompt.version != version:
|
||||||
|
raise ValueError(
|
||||||
|
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||||
|
)
|
||||||
|
|
||||||
|
current_version = latest_prompt.version if version is None else version
|
||||||
|
new_version = current_version + 1
|
||||||
|
|
||||||
|
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||||
|
|
||||||
|
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||||
|
data = self._serialize_prompt(updated_prompt)
|
||||||
|
await self.kvstore.set(version_key, data)
|
||||||
|
|
||||||
|
if set_as_default:
|
||||||
|
await self.set_default_version(prompt_id, new_version)
|
||||||
|
|
||||||
|
return updated_prompt
|
||||||
|
|
||||||
|
async def delete_prompt(self, prompt_id: str) -> None:
|
||||||
|
"""Delete a prompt and all its versions."""
|
||||||
|
await self.get_prompt(prompt_id)
|
||||||
|
|
||||||
|
prefix = f"prompts:v1:{prompt_id}:"
|
||||||
|
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
await self.kvstore.delete(key)
|
||||||
|
|
||||||
|
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||||
|
"""List all versions of a specific prompt."""
|
||||||
|
prefix = f"prompts:v1:{prompt_id}:"
|
||||||
|
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||||
|
|
||||||
|
default_version = None
|
||||||
|
prompts = []
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
data = await self.kvstore.get(key)
|
||||||
|
if key.endswith(":default"):
|
||||||
|
default_version = data
|
||||||
|
else:
|
||||||
|
if data:
|
||||||
|
prompt_obj = self._deserialize_prompt(data)
|
||||||
|
prompts.append(prompt_obj)
|
||||||
|
|
||||||
|
if not prompts:
|
||||||
|
raise ValueError(f"Prompt {prompt_id} not found")
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
prompt.is_default = str(prompt.version) == default_version
|
||||||
|
|
||||||
|
prompts.sort(key=lambda x: x.version)
|
||||||
|
return ListPromptsResponse(data=prompts)
|
||||||
|
|
||||||
|
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||||
|
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||||
|
version_key = self._get_version_key(prompt_id, str(version))
|
||||||
|
data = await self.kvstore.get(version_key)
|
||||||
|
if data is None:
|
||||||
|
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||||
|
|
||||||
|
default_key = self._get_default_key(prompt_id)
|
||||||
|
await self.kvstore.set(default_key, str(version))
|
||||||
|
|
||||||
|
return self._deserialize_prompt(data)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
31
src/llama_stack/providers/registry/prompts.py
Normal file
31
src/llama_stack/providers/registry/prompts.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
def available_providers() -> list[ProviderSpec]:
|
||||||
|
return [
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.prompts,
|
||||||
|
provider_type="inline::reference",
|
||||||
|
pip_packages=[],
|
||||||
|
module="llama_stack.providers.inline.prompts.reference",
|
||||||
|
config_class="llama_stack.providers.inline.prompts.reference.ReferencePromptsConfig",
|
||||||
|
description="Reference implementation storing prompts in KVStore (SQLite, PostgreSQL, etc.)",
|
||||||
|
),
|
||||||
|
RemoteProviderSpec(
|
||||||
|
api=Api.prompts,
|
||||||
|
adapter_type="mlflow",
|
||||||
|
provider_type="remote::mlflow",
|
||||||
|
pip_packages=["mlflow>=3.4.0"],
|
||||||
|
module="llama_stack.providers.remote.prompts.mlflow",
|
||||||
|
config_class="llama_stack.providers.remote.prompts.mlflow.MLflowPromptsConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.prompts.mlflow.config.MLflowProviderDataValidator",
|
||||||
|
description="MLflow Prompt Registry provider for centralized prompt management and versioning",
|
||||||
|
),
|
||||||
|
]
|
||||||
5
src/llama_stack/providers/remote/prompts/__init__.py
Normal file
5
src/llama_stack/providers/remote/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
17
src/llama_stack/providers/remote/prompts/mlflow/__init__.py
Normal file
17
src/llama_stack/providers/remote/prompts/mlflow/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import MLflowPromptsConfig
|
||||||
|
from .mlflow import MLflowPromptsAdapter
|
||||||
|
|
||||||
|
__all__ = ["MLflowPromptsConfig", "MLflowPromptsAdapter", "get_adapter_impl"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: MLflowPromptsConfig, _deps):
|
||||||
|
"""Get the MLflow prompts adapter implementation."""
|
||||||
|
impl = MLflowPromptsAdapter(config=config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
105
src/llama_stack/providers/remote/prompts/mlflow/config.py
Normal file
105
src/llama_stack/providers/remote/prompts/mlflow/config.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Configuration for MLflow Prompt Registry provider.
|
||||||
|
|
||||||
|
This module defines the configuration schema for integrating Llama Stack
|
||||||
|
with MLflow's Prompt Registry for centralized prompt management and versioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, SecretStr, field_validator
|
||||||
|
|
||||||
|
from llama_stack_api import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class MLflowProviderDataValidator(BaseModel):
|
||||||
|
"""Validator for provider data from request headers.
|
||||||
|
|
||||||
|
This allows users to override the MLflow API token per request via
|
||||||
|
the x-llamastack-provider-data header:
|
||||||
|
{"mlflow_api_token": "your-token"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
mlflow_api_token: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="MLflow API token for authentication (overrides config)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MLflowPromptsConfig(BaseModel):
|
||||||
|
"""Configuration for MLflow Prompt Registry provider.
|
||||||
|
|
||||||
|
Credentials can be provided via:
|
||||||
|
1. Per-request provider data header (preferred for security)
|
||||||
|
2. Configuration auth_credential (fallback)
|
||||||
|
3. Environment variables set by MLflow (MLFLOW_TRACKING_TOKEN, etc.)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI (e.g., http://localhost:5000, databricks)
|
||||||
|
mlflow_registry_uri: MLflow registry URI (optional, defaults to tracking_uri)
|
||||||
|
experiment_name: MLflow experiment name for prompt storage
|
||||||
|
auth_credential: MLflow API token for authentication (optional, can be overridden by provider data)
|
||||||
|
timeout_seconds: Timeout for MLflow API calls in seconds (default: 30)
|
||||||
|
"""
|
||||||
|
|
||||||
|
mlflow_tracking_uri: str = Field(
|
||||||
|
default="http://localhost:5000",
|
||||||
|
description="MLflow tracking server URI (e.g., http://localhost:5000, databricks, databricks://profile)",
|
||||||
|
)
|
||||||
|
mlflow_registry_uri: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="MLflow registry URI (defaults to tracking_uri if not specified)",
|
||||||
|
)
|
||||||
|
experiment_name: str = Field(
|
||||||
|
default="llama-stack-prompts",
|
||||||
|
description="MLflow experiment name for prompt storage and organization",
|
||||||
|
)
|
||||||
|
auth_credential: SecretStr | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="MLflow API token for authentication. Can be overridden via provider data header.",
|
||||||
|
)
|
||||||
|
timeout_seconds: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=300,
|
||||||
|
description="Timeout for MLflow API calls in seconds (1-300)",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, mlflow_api_token: str = "${env.MLFLOW_TRACKING_TOKEN:=}", **kwargs) -> dict[str, Any]:
|
||||||
|
"""Generate sample configuration with environment variable substitution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mlflow_api_token: MLflow API token (defaults to MLFLOW_TRACKING_TOKEN env var)
|
||||||
|
**kwargs: Additional configuration overrides
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sample configuration dictionary
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"mlflow_tracking_uri": kwargs.get("mlflow_tracking_uri", "http://localhost:5000"),
|
||||||
|
"experiment_name": kwargs.get("experiment_name", "llama-stack-prompts"),
|
||||||
|
"auth_credential": mlflow_api_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
@field_validator("mlflow_tracking_uri")
|
||||||
|
@classmethod
|
||||||
|
def validate_tracking_uri(cls, v: str) -> str:
|
||||||
|
"""Validate tracking URI is not empty."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("mlflow_tracking_uri cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@field_validator("experiment_name")
|
||||||
|
@classmethod
|
||||||
|
def validate_experiment_name(cls, v: str) -> str:
|
||||||
|
"""Validate experiment name is not empty."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("experiment_name cannot be empty")
|
||||||
|
return v.strip()
|
||||||
123
src/llama_stack/providers/remote/prompts/mlflow/mapping.py
Normal file
123
src/llama_stack/providers/remote/prompts/mlflow/mapping.py
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""ID mapping utilities for MLflow Prompt Registry provider.
|
||||||
|
|
||||||
|
This module handles bidirectional mapping between Llama Stack's prompt_id format
|
||||||
|
(pmpt_<48-hex-chars>) and MLflow's name-based system (llama_prompt_<hex>).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class PromptIDMapper:
|
||||||
|
"""Handle bidirectional mapping between Llama Stack IDs and MLflow names.
|
||||||
|
|
||||||
|
Llama Stack uses prompt IDs in format: pmpt_<48-hex-chars>
|
||||||
|
MLflow uses string names, so we map to: llama_prompt_<48-hex-chars>
|
||||||
|
|
||||||
|
This ensures:
|
||||||
|
- Deterministic mapping (same ID always maps to same name)
|
||||||
|
- Reversible (can recover original ID from MLflow name)
|
||||||
|
- Unique (different IDs map to different names)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Regex pattern for Llama Stack prompt_id validation
|
||||||
|
PROMPT_ID_PATTERN = re.compile(r"^pmpt_[0-9a-f]{48}$")
|
||||||
|
|
||||||
|
# Prefix for MLflow prompt names managed by Llama Stack
|
||||||
|
MLFLOW_NAME_PREFIX = "llama_prompt_"
|
||||||
|
|
||||||
|
def to_mlflow_name(self, prompt_id: str) -> str:
|
||||||
|
"""Convert Llama Stack prompt_id to MLflow prompt name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID (format: pmpt_<48-hex-chars>)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MLflow prompt name (format: llama_prompt_<48-hex-chars>)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If prompt_id format is invalid
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> mapper = PromptIDMapper()
|
||||||
|
>>> mapper.to_mlflow_name("pmpt_a1b2c3d4e5f6...")
|
||||||
|
"llama_prompt_a1b2c3d4e5f6..."
|
||||||
|
"""
|
||||||
|
if not self.PROMPT_ID_PATTERN.match(prompt_id):
|
||||||
|
raise ValueError(f"Invalid prompt_id format: {prompt_id}. Expected format: pmpt_<48-hex-chars>")
|
||||||
|
|
||||||
|
# Extract hex part (after "pmpt_" prefix)
|
||||||
|
hex_part = prompt_id.split("pmpt_")[1]
|
||||||
|
|
||||||
|
# Create MLflow name
|
||||||
|
return f"{self.MLFLOW_NAME_PREFIX}{hex_part}"
|
||||||
|
|
||||||
|
def to_llama_id(self, mlflow_name: str) -> str:
|
||||||
|
"""Convert MLflow prompt name to Llama Stack prompt_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mlflow_name: MLflow prompt name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Llama Stack prompt ID (format: pmpt_<48-hex-chars>)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If name doesn't follow expected format
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> mapper = PromptIDMapper()
|
||||||
|
>>> mapper.to_llama_id("llama_prompt_a1b2c3d4e5f6...")
|
||||||
|
"pmpt_a1b2c3d4e5f6..."
|
||||||
|
"""
|
||||||
|
if not mlflow_name.startswith(self.MLFLOW_NAME_PREFIX):
|
||||||
|
raise ValueError(
|
||||||
|
f"MLflow name '{mlflow_name}' does not start with expected prefix '{self.MLFLOW_NAME_PREFIX}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract hex part
|
||||||
|
hex_part = mlflow_name[len(self.MLFLOW_NAME_PREFIX) :]
|
||||||
|
|
||||||
|
# Validate hex part length and characters
|
||||||
|
if len(hex_part) != 48:
|
||||||
|
raise ValueError(f"Invalid hex part length in MLflow name '{mlflow_name}'. Expected 48 characters.")
|
||||||
|
|
||||||
|
for char in hex_part:
|
||||||
|
if char not in "0123456789abcdef":
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid character '{char}' in hex part of MLflow name '{mlflow_name}'. "
|
||||||
|
"Expected lowercase hex characters [0-9a-f]."
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"pmpt_{hex_part}"
|
||||||
|
|
||||||
|
def get_metadata_tags(self, prompt_id: str, variables: list[str] | None = None) -> dict[str, str]:
|
||||||
|
"""Generate MLflow tags with Llama Stack metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
variables: List of prompt variables (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of MLflow tags for metadata storage
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> mapper = PromptIDMapper()
|
||||||
|
>>> tags = mapper.get_metadata_tags("pmpt_abc123...", ["var1", "var2"])
|
||||||
|
>>> tags
|
||||||
|
{"llama_stack_id": "pmpt_abc123...", "llama_stack_managed": "true", "variables": "var1,var2"}
|
||||||
|
"""
|
||||||
|
tags = {
|
||||||
|
"llama_stack_id": prompt_id,
|
||||||
|
"llama_stack_managed": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
if variables:
|
||||||
|
# Store variables as comma-separated string
|
||||||
|
tags["variables"] = ",".join(variables)
|
||||||
|
|
||||||
|
return tags
|
||||||
547
src/llama_stack/providers/remote/prompts/mlflow/mlflow.py
Normal file
547
src/llama_stack/providers/remote/prompts/mlflow/mlflow.py
Normal file
|
|
@ -0,0 +1,547 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""MLflow Prompt Registry provider implementation.
|
||||||
|
|
||||||
|
This module implements the Llama Stack Prompts protocol using MLflow's Prompt Registry
|
||||||
|
as the backend for centralized prompt management and versioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper
|
||||||
|
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
||||||
|
|
||||||
|
# Try importing mlflow at module level
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
from mlflow.client import MlflowClient
|
||||||
|
except ImportError:
|
||||||
|
# Fail gracefully when provider is instantiated during initialize()
|
||||||
|
mlflow = None
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
class MLflowPromptsAdapter(NeedsRequestProviderData, Prompts):
|
||||||
|
"""MLflow Prompt Registry adapter for Llama Stack.
|
||||||
|
|
||||||
|
This adapter implements the Llama Stack Prompts protocol using MLflow's
|
||||||
|
Prompt Registry as the backend storage system. It handles:
|
||||||
|
|
||||||
|
- Bidirectional ID mapping (prompt_id <-> MLflow name)
|
||||||
|
- Version management via MLflow versioning
|
||||||
|
- Variable extraction from prompt templates
|
||||||
|
- Metadata storage in MLflow tags
|
||||||
|
- Default version management via MLflow aliases
|
||||||
|
- Credential management via provider data (backstopped by config)
|
||||||
|
|
||||||
|
Credentials can be provided via:
|
||||||
|
1. Per-request provider data header (preferred for security)
|
||||||
|
2. Configuration auth_credential (fallback)
|
||||||
|
3. Environment variables (MLFLOW_TRACKING_TOKEN, etc.)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config: MLflow provider configuration
|
||||||
|
mlflow_client: MLflow client instance
|
||||||
|
mapper: ID mapping utility
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: MLflowPromptsConfig):
|
||||||
|
"""Initialize MLflow prompts adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: MLflow provider configuration
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.mlflow_client: "MlflowClient | None" = None
|
||||||
|
self.mapper = PromptIDMapper()
|
||||||
|
logger.info(
|
||||||
|
f"MLflowPromptsAdapter initialized: tracking_uri={config.mlflow_tracking_uri}, "
|
||||||
|
f"experiment={config.experiment_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize MLflow client and set up experiment.
|
||||||
|
|
||||||
|
Sets up MLflow connection with optional authentication via token.
|
||||||
|
Token can be provided via config or will be read from environment variables
|
||||||
|
(MLFLOW_TRACKING_TOKEN, etc.) as per MLflow's standard behavior.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If mlflow package is not installed
|
||||||
|
Exception: If MLflow connection fails
|
||||||
|
"""
|
||||||
|
if mlflow is None:
|
||||||
|
raise ImportError(
|
||||||
|
"mlflow package is required for MLflow prompts provider. "
|
||||||
|
"Install with: pip install 'mlflow>=3.4.0'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set MLflow URIs
|
||||||
|
mlflow.set_tracking_uri(self.config.mlflow_tracking_uri)
|
||||||
|
|
||||||
|
if self.config.mlflow_registry_uri:
|
||||||
|
mlflow.set_registry_uri(self.config.mlflow_registry_uri)
|
||||||
|
else:
|
||||||
|
# Default to tracking URI if registry not specified
|
||||||
|
mlflow.set_registry_uri(self.config.mlflow_tracking_uri)
|
||||||
|
|
||||||
|
# Set authentication token if provided in config
|
||||||
|
if self.config.auth_credential is not None:
|
||||||
|
import os
|
||||||
|
|
||||||
|
# MLflow reads MLFLOW_TRACKING_TOKEN from environment
|
||||||
|
os.environ["MLFLOW_TRACKING_TOKEN"] = self.config.auth_credential.get_secret_value()
|
||||||
|
logger.debug("Set MLFLOW_TRACKING_TOKEN from config auth_credential")
|
||||||
|
|
||||||
|
# Initialize client
|
||||||
|
self.mlflow_client = MlflowClient()
|
||||||
|
|
||||||
|
# Validate experiment exists (don't create during initialization)
|
||||||
|
try:
|
||||||
|
mlflow.set_experiment(self.config.experiment_name)
|
||||||
|
logger.info(f"Using MLflow experiment: {self.config.experiment_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Experiment '{self.config.experiment_name}' not found: {e}. "
|
||||||
|
f"It will be created automatically on first prompt creation."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_experiment(self) -> None:
|
||||||
|
"""Ensure MLflow experiment exists, creating it if necessary.
|
||||||
|
|
||||||
|
This is called lazily on first write operation to avoid creating
|
||||||
|
external resources during initialization.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
mlflow.set_experiment(self.config.experiment_name)
|
||||||
|
except Exception:
|
||||||
|
# Experiment doesn't exist, create it
|
||||||
|
try:
|
||||||
|
mlflow.create_experiment(self.config.experiment_name)
|
||||||
|
mlflow.set_experiment(self.config.experiment_name)
|
||||||
|
logger.info(f"Created MLflow experiment: {self.config.experiment_name}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to create experiment '{self.config.experiment_name}': {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _extract_variables(self, template: str) -> list[str]:
|
||||||
|
"""Extract variables from prompt template.
|
||||||
|
|
||||||
|
Extracts variables in {{ variable }} format from the template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: Prompt template string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique variable names in order of appearance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> adapter._extract_variables("Hello {{ name }}, your score is {{ score }}")
|
||||||
|
["name", "score"]
|
||||||
|
"""
|
||||||
|
if not template:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find all {{ variable }} patterns
|
||||||
|
matches = re.findall(r"{{\s*(\w+)\s*}}", template)
|
||||||
|
|
||||||
|
# Return unique variables in order of appearance
|
||||||
|
seen = set()
|
||||||
|
variables = []
|
||||||
|
for var in matches:
|
||||||
|
if var not in seen:
|
||||||
|
variables.append(var)
|
||||||
|
seen.add(var)
|
||||||
|
|
||||||
|
return variables
|
||||||
|
|
||||||
|
async def create_prompt(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
variables: list[str] | None = None,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Create a new prompt in MLflow registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Prompt template text with {{ variable }} placeholders
|
||||||
|
variables: List of variable names (auto-extracted if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created Prompt resource with prompt_id and version=1
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If prompt validation fails
|
||||||
|
Exception: If MLflow registration fails
|
||||||
|
"""
|
||||||
|
# Ensure experiment exists (lazy creation on first write)
|
||||||
|
self._ensure_experiment()
|
||||||
|
|
||||||
|
# Auto-extract variables if not provided
|
||||||
|
if variables is None:
|
||||||
|
variables = self._extract_variables(prompt)
|
||||||
|
else:
|
||||||
|
# Validate declared variables match template
|
||||||
|
template_vars = set(self._extract_variables(prompt))
|
||||||
|
declared_vars = set(variables)
|
||||||
|
undeclared = template_vars - declared_vars
|
||||||
|
if undeclared:
|
||||||
|
raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}")
|
||||||
|
|
||||||
|
# Generate Llama Stack prompt_id
|
||||||
|
prompt_id = Prompt.generate_prompt_id()
|
||||||
|
|
||||||
|
# Convert to MLflow name
|
||||||
|
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
|
||||||
|
|
||||||
|
# Prepare metadata tags
|
||||||
|
tags = self.mapper.get_metadata_tags(prompt_id, variables)
|
||||||
|
|
||||||
|
# Register in MLflow
|
||||||
|
try:
|
||||||
|
mlflow.genai.register_prompt(
|
||||||
|
name=mlflow_name,
|
||||||
|
template=prompt,
|
||||||
|
commit_message="Created via Llama Stack",
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
logger.info(f"Created prompt {prompt_id} (MLflow: {mlflow_name})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to register prompt in MLflow: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Set as default (first version is always default)
|
||||||
|
try:
|
||||||
|
mlflow.genai.set_prompt_alias(
|
||||||
|
name=mlflow_name,
|
||||||
|
version=1,
|
||||||
|
alias="default",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to set default alias for {prompt_id}: {e}")
|
||||||
|
|
||||||
|
return Prompt(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
prompt=prompt,
|
||||||
|
version=1,
|
||||||
|
variables=variables,
|
||||||
|
is_default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_prompt(
|
||||||
|
self,
|
||||||
|
prompt_id: str,
|
||||||
|
version: int | None = None,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Get prompt from MLflow registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
version: Version number (defaults to default version)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt resource
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If prompt not found
|
||||||
|
"""
|
||||||
|
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
|
||||||
|
|
||||||
|
# Build MLflow URI
|
||||||
|
if version:
|
||||||
|
uri = f"prompts:/{mlflow_name}/{version}"
|
||||||
|
else:
|
||||||
|
uri = f"prompts:/{mlflow_name}@default"
|
||||||
|
|
||||||
|
# Load from MLflow
|
||||||
|
try:
|
||||||
|
mlflow_prompt = mlflow.genai.load_prompt(uri)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Prompt {prompt_id} (version {version if version else 'default'}) not found: {e}") from e
|
||||||
|
|
||||||
|
# Extract template
|
||||||
|
template = mlflow_prompt.template if hasattr(mlflow_prompt, "template") else str(mlflow_prompt)
|
||||||
|
|
||||||
|
# Extract variables from template
|
||||||
|
variables = self._extract_variables(template)
|
||||||
|
|
||||||
|
# Get version number
|
||||||
|
prompt_version = 1
|
||||||
|
if hasattr(mlflow_prompt, "version"):
|
||||||
|
prompt_version = int(mlflow_prompt.version)
|
||||||
|
elif version:
|
||||||
|
prompt_version = version
|
||||||
|
|
||||||
|
# Check if this is the default version
|
||||||
|
is_default = await self._is_default_version(mlflow_name, prompt_version)
|
||||||
|
|
||||||
|
return Prompt(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
prompt=template,
|
||||||
|
version=prompt_version,
|
||||||
|
variables=variables,
|
||||||
|
is_default=is_default,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_prompt(
|
||||||
|
self,
|
||||||
|
prompt_id: str,
|
||||||
|
prompt: str,
|
||||||
|
version: int,
|
||||||
|
variables: list[str] | None = None,
|
||||||
|
set_as_default: bool = True,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Update prompt (creates new version in MLflow).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
prompt: Updated prompt template
|
||||||
|
version: Current version being updated
|
||||||
|
variables: Updated variables list (auto-extracted if not provided)
|
||||||
|
set_as_default: Set new version as default
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated Prompt resource with incremented version
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If current version not found or validation fails
|
||||||
|
"""
|
||||||
|
# Ensure experiment exists (edge case: updating prompts created outside Llama Stack)
|
||||||
|
self._ensure_experiment()
|
||||||
|
|
||||||
|
# Auto-extract variables if not provided
|
||||||
|
if variables is None:
|
||||||
|
variables = self._extract_variables(prompt)
|
||||||
|
else:
|
||||||
|
# Validate variables
|
||||||
|
template_vars = set(self._extract_variables(prompt))
|
||||||
|
declared_vars = set(variables)
|
||||||
|
undeclared = template_vars - declared_vars
|
||||||
|
if undeclared:
|
||||||
|
raise ValueError(f"Template contains undeclared variables: {sorted(undeclared)}")
|
||||||
|
|
||||||
|
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
|
||||||
|
|
||||||
|
# Get all versions to determine the latest and next version number
|
||||||
|
versions_response = await self.list_prompt_versions(prompt_id)
|
||||||
|
if not versions_response.data:
|
||||||
|
raise ValueError(f"Prompt {prompt_id} not found")
|
||||||
|
|
||||||
|
max_version = max(p.version for p in versions_response.data)
|
||||||
|
|
||||||
|
# Verify the provided version is the latest
|
||||||
|
if version != max_version:
|
||||||
|
raise ValueError(
|
||||||
|
f"Version {version} is not the latest version. Use latest version {max_version} to update."
|
||||||
|
)
|
||||||
|
|
||||||
|
new_version = max_version + 1
|
||||||
|
|
||||||
|
# Prepare metadata tags
|
||||||
|
tags = self.mapper.get_metadata_tags(prompt_id, variables)
|
||||||
|
|
||||||
|
# Register new version in MLflow
|
||||||
|
try:
|
||||||
|
mlflow.genai.register_prompt(
|
||||||
|
name=mlflow_name,
|
||||||
|
template=prompt,
|
||||||
|
commit_message=f"Updated from version {version} via Llama Stack",
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
logger.info(f"Updated prompt {prompt_id} to version {new_version}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update prompt in MLflow: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Set as default if requested
|
||||||
|
if set_as_default:
|
||||||
|
try:
|
||||||
|
mlflow.genai.set_prompt_alias(
|
||||||
|
name=mlflow_name,
|
||||||
|
version=new_version,
|
||||||
|
alias="default",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to set default alias: {e}")
|
||||||
|
|
||||||
|
return Prompt(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
prompt=prompt,
|
||||||
|
version=new_version,
|
||||||
|
variables=variables,
|
||||||
|
is_default=set_as_default,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_prompt(self, prompt_id: str) -> None:
|
||||||
|
"""Delete prompt from MLflow registry.
|
||||||
|
|
||||||
|
Note: MLflow Prompt Registry does not support deletion of registered prompts.
|
||||||
|
This method will raise NotImplementedError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: MLflow doesn't support prompt deletion
|
||||||
|
"""
|
||||||
|
# MLflow doesn't support deletion of registered prompts
|
||||||
|
# Options:
|
||||||
|
# 1. Raise NotImplementedError (current approach)
|
||||||
|
# 2. Mark as deleted with tag (soft delete)
|
||||||
|
# 3. Delete all versions individually (if API exists)
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
"MLflow Prompt Registry does not support deletion. Consider using tags to mark prompts as archived/deleted."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_prompts(self) -> ListPromptsResponse:
|
||||||
|
"""List all prompts (default versions only).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ListPromptsResponse with default version of each prompt
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Only lists prompts created/managed by Llama Stack
|
||||||
|
(those with llama_stack_managed=true tag)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Search for Llama Stack managed prompts using metadata tags
|
||||||
|
prompts = mlflow.genai.search_prompts(filter_string="tag.llama_stack_managed='true'")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to search prompts in MLflow: {e}")
|
||||||
|
return ListPromptsResponse(data=[])
|
||||||
|
|
||||||
|
llama_prompts = []
|
||||||
|
for mlflow_prompt in prompts:
|
||||||
|
try:
|
||||||
|
# Convert MLflow name to Llama Stack ID
|
||||||
|
prompt_id = self.mapper.to_llama_id(mlflow_prompt.name)
|
||||||
|
|
||||||
|
# Get default version
|
||||||
|
llama_prompt = await self.get_prompt(prompt_id)
|
||||||
|
llama_prompts.append(llama_prompt)
|
||||||
|
except (ValueError, Exception) as e:
|
||||||
|
# Skip prompts that can't be converted or retrieved
|
||||||
|
logger.warning(f"Skipping prompt {mlflow_prompt.name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Sort by prompt_id
|
||||||
|
llama_prompts.sort(key=lambda p: p.prompt_id, reverse=True)
|
||||||
|
|
||||||
|
return ListPromptsResponse(data=llama_prompts)
|
||||||
|
|
||||||
|
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||||
|
"""List all versions of a specific prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ListPromptsResponse with all versions of the prompt
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If prompt not found
|
||||||
|
"""
|
||||||
|
# MLflow doesn't have a direct "list versions" API for prompts
|
||||||
|
# We need to iterate and try to load each version
|
||||||
|
versions = []
|
||||||
|
version_num = 1
|
||||||
|
max_attempts = 100 # Safety limit
|
||||||
|
|
||||||
|
while version_num <= max_attempts:
|
||||||
|
try:
|
||||||
|
prompt = await self.get_prompt(prompt_id, version_num)
|
||||||
|
versions.append(prompt)
|
||||||
|
version_num += 1
|
||||||
|
except ValueError:
|
||||||
|
# No more versions
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading version {version_num} of {prompt_id}: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not versions:
|
||||||
|
raise ValueError(f"Prompt {prompt_id} not found")
|
||||||
|
|
||||||
|
# Sort by version number
|
||||||
|
versions.sort(key=lambda p: p.version)
|
||||||
|
|
||||||
|
return ListPromptsResponse(data=versions)
|
||||||
|
|
||||||
|
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||||
|
"""Set default version using MLflow alias.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id: Llama Stack prompt ID
|
||||||
|
version: Version number to set as default
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt resource with is_default=True
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If version not found
|
||||||
|
"""
|
||||||
|
# Ensure experiment exists (edge case: managing prompts created outside Llama Stack)
|
||||||
|
self._ensure_experiment()
|
||||||
|
|
||||||
|
mlflow_name = self.mapper.to_mlflow_name(prompt_id)
|
||||||
|
|
||||||
|
# Verify version exists
|
||||||
|
try:
|
||||||
|
prompt = await self.get_prompt(prompt_id, version)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Cannot set default: {e}") from e
|
||||||
|
|
||||||
|
# Set "default" alias in MLflow
|
||||||
|
try:
|
||||||
|
mlflow.genai.set_prompt_alias(
|
||||||
|
name=mlflow_name,
|
||||||
|
version=version,
|
||||||
|
alias="default",
|
||||||
|
)
|
||||||
|
logger.info(f"Set version {version} as default for {prompt_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to set default version: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Update is_default flag
|
||||||
|
prompt.is_default = True
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
async def _is_default_version(self, mlflow_name: str, version: int) -> bool:
|
||||||
|
"""Check if a version is the default version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mlflow_name: MLflow prompt name
|
||||||
|
version: Version number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this version is the default, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to load with @default alias
|
||||||
|
default_uri = f"prompts:/{mlflow_name}@default"
|
||||||
|
default_prompt = mlflow.genai.load_prompt(default_uri)
|
||||||
|
|
||||||
|
# Get default version number
|
||||||
|
default_version = 1
|
||||||
|
if hasattr(default_prompt, "version"):
|
||||||
|
default_version = int(default_prompt.version)
|
||||||
|
|
||||||
|
return version == default_version
|
||||||
|
except Exception:
|
||||||
|
# If default doesn't exist or can't be determined, assume False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Cleanup resources (no-op for MLflow)."""
|
||||||
|
pass
|
||||||
7
tests/integration/providers/remote/prompts/__init__.py
Normal file
7
tests/integration/providers/remote/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Integration tests for remote prompts providers."""
|
||||||
274
tests/integration/providers/remote/prompts/mlflow/README.md
Normal file
274
tests/integration/providers/remote/prompts/mlflow/README.md
Normal file
|
|
@ -0,0 +1,274 @@
|
||||||
|
# MLflow Prompts Provider - Integration Tests
|
||||||
|
|
||||||
|
This directory contains integration tests for the MLflow Prompts Provider. These tests require a running MLflow server.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
1. **MLflow installed**: `pip install 'mlflow>=3.4.0'` (or `uv pip install 'mlflow>=3.4.0'`)
|
||||||
|
2. **MLflow server running**: See setup instructions below
|
||||||
|
3. **Test dependencies**: `uv sync --group test`
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Start MLflow Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start MLflow server on localhost:5555
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
|
||||||
|
# Keep this terminal open - server will continue running
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Run Integration Tests
|
||||||
|
|
||||||
|
In a separate terminal:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Set MLflow URI (optional - defaults to localhost:5555)
|
||||||
|
export MLFLOW_TRACKING_URI=http://localhost:5555
|
||||||
|
|
||||||
|
# Run all integration tests
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py::TestMLflowPromptsEndToEnd::test_create_and_retrieve_prompt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Run Manual Test Script (Optional)
|
||||||
|
|
||||||
|
For quick validation without pytest:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run manual test script
|
||||||
|
uv run python scripts/test_mlflow_prompts_manual.py
|
||||||
|
|
||||||
|
# View output in MLflow UI
|
||||||
|
open http://localhost:5555
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Organization
|
||||||
|
|
||||||
|
### Integration Tests (`test_end_to_end.py`)
|
||||||
|
|
||||||
|
Comprehensive end-to-end tests covering:
|
||||||
|
|
||||||
|
- ✅ Create and retrieve prompts
|
||||||
|
- ✅ Update prompts (version management)
|
||||||
|
- ✅ List prompts (default versions only)
|
||||||
|
- ✅ List all versions of a prompt
|
||||||
|
- ✅ Set default version
|
||||||
|
- ✅ Variable auto-extraction
|
||||||
|
- ✅ Variable validation
|
||||||
|
- ✅ Error handling (not found, wrong version, etc.)
|
||||||
|
- ✅ Complex templates with multiple variables
|
||||||
|
- ✅ Edge cases (empty templates, no variables, etc.)
|
||||||
|
|
||||||
|
**Total**: 17 test scenarios
|
||||||
|
|
||||||
|
### Manual Test Script (`scripts/test_mlflow_prompts_manual.py`)
|
||||||
|
|
||||||
|
Interactive test script with verbose output for:
|
||||||
|
|
||||||
|
- Server connectivity check
|
||||||
|
- Provider initialization
|
||||||
|
- Basic CRUD operations
|
||||||
|
- Variable extraction
|
||||||
|
- Statistics retrieval
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### MLflow Server Options
|
||||||
|
|
||||||
|
**Local (default)**:
|
||||||
|
```bash
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
```
|
||||||
|
|
||||||
|
**Remote server**:
|
||||||
|
```bash
|
||||||
|
export MLFLOW_TRACKING_URI=http://mlflow.example.com:5000
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
```
|
||||||
|
|
||||||
|
**Databricks**:
|
||||||
|
```bash
|
||||||
|
export MLFLOW_TRACKING_URI=databricks
|
||||||
|
export MLFLOW_REGISTRY_URI=databricks://profile
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Timeout
|
||||||
|
|
||||||
|
Tests have a default timeout of 30 seconds per MLflow operation. Adjust in `conftest.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
MLflowPromptsConfig(
|
||||||
|
mlflow_tracking_uri=mlflow_tracking_uri,
|
||||||
|
timeout_seconds=60, # Increase for slow connections
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fixtures
|
||||||
|
|
||||||
|
### `mlflow_adapter`
|
||||||
|
|
||||||
|
Basic adapter for simple tests:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def test_something(mlflow_adapter):
|
||||||
|
prompt = await mlflow_adapter.create_prompt(...)
|
||||||
|
# Test continues...
|
||||||
|
```
|
||||||
|
|
||||||
|
### `mlflow_adapter_with_cleanup`
|
||||||
|
|
||||||
|
Adapter with automatic cleanup tracking:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def test_something(mlflow_adapter_with_cleanup):
|
||||||
|
# Creates are tracked and attempted cleanup on teardown
|
||||||
|
prompt = await mlflow_adapter_with_cleanup.create_prompt(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: MLflow doesn't support deletion, so cleanup is best-effort.
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Server Not Available
|
||||||
|
|
||||||
|
**Symptom**:
|
||||||
|
```
|
||||||
|
SKIPPED [1] conftest.py:35: MLflow server not available at http://localhost:5555
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
```bash
|
||||||
|
# Start MLflow server
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
|
||||||
|
# Verify it's running
|
||||||
|
curl http://localhost:5555/health
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection Timeout
|
||||||
|
|
||||||
|
**Symptom**:
|
||||||
|
```
|
||||||
|
requests.exceptions.Timeout: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solutions**:
|
||||||
|
1. Check MLflow server is responsive: `curl http://localhost:5555/health`
|
||||||
|
2. Increase timeout in `conftest.py`: `timeout_seconds=60`
|
||||||
|
3. Check firewall/network settings
|
||||||
|
|
||||||
|
### Import Errors
|
||||||
|
|
||||||
|
**Symptom**:
|
||||||
|
```
|
||||||
|
ModuleNotFoundError: No module named 'mlflow'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
```bash
|
||||||
|
uv pip install 'mlflow>=3.4.0'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Permission Errors
|
||||||
|
|
||||||
|
**Symptom**:
|
||||||
|
```
|
||||||
|
PermissionError: [Errno 13] Permission denied: '...'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**:
|
||||||
|
- Ensure MLflow has write access to its storage directory
|
||||||
|
- Check file permissions on `mlruns/` directory
|
||||||
|
|
||||||
|
### Test Isolation Issues
|
||||||
|
|
||||||
|
**Issue**: Tests may interfere with each other if using same prompt IDs
|
||||||
|
|
||||||
|
**Solution**: Each test creates new prompts with unique IDs (generated by `Prompt.generate_prompt_id()`). If needed, use `mlflow_adapter_with_cleanup` fixture.
|
||||||
|
|
||||||
|
## Viewing Results
|
||||||
|
|
||||||
|
### MLflow UI
|
||||||
|
|
||||||
|
1. Start MLflow server (if not already running):
|
||||||
|
```bash
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Open in browser:
|
||||||
|
```
|
||||||
|
http://localhost:5555
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Navigate to experiment `test-llama-stack-prompts`
|
||||||
|
|
||||||
|
4. View registered prompts and their versions
|
||||||
|
|
||||||
|
### Test Output
|
||||||
|
|
||||||
|
Run with verbose output to see detailed test execution:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --group test pytest -vv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
```
|
||||||
|
|
||||||
|
## CI/CD Integration
|
||||||
|
|
||||||
|
To run tests in CI/CD pipelines:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example GitHub Actions workflow
|
||||||
|
- name: Start MLflow server
|
||||||
|
run: |
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555 &
|
||||||
|
sleep 5 # Wait for server to start
|
||||||
|
|
||||||
|
- name: Wait for MLflow
|
||||||
|
run: |
|
||||||
|
timeout 30 bash -c 'until curl -s http://localhost:5555/health; do sleep 1; done'
|
||||||
|
|
||||||
|
- name: Run integration tests
|
||||||
|
env:
|
||||||
|
MLFLOW_TRACKING_URI: http://localhost:5555
|
||||||
|
run: |
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
### Expected Test Duration
|
||||||
|
|
||||||
|
- **Individual test**: ~1-5 seconds
|
||||||
|
- **Full suite** (17 tests): ~30-60 seconds
|
||||||
|
- **Manual script**: ~10-15 seconds
|
||||||
|
|
||||||
|
### Optimization Tips
|
||||||
|
|
||||||
|
1. Use local MLflow server (faster than remote)
|
||||||
|
2. Run tests in parallel (if safe):
|
||||||
|
```bash
|
||||||
|
uv run --group test pytest -n auto tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
```
|
||||||
|
3. Skip integration tests in development:
|
||||||
|
```bash
|
||||||
|
uv run --group dev pytest -sv tests/unit/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Coverage
|
||||||
|
|
||||||
|
Integration tests provide coverage for:
|
||||||
|
|
||||||
|
- ✅ Real MLflow API calls
|
||||||
|
- ✅ Network communication
|
||||||
|
- ✅ Serialization/deserialization
|
||||||
|
- ✅ MLflow server responses
|
||||||
|
- ✅ Version management
|
||||||
|
- ✅ Alias handling
|
||||||
|
- ✅ Tag storage and retrieval
|
||||||
|
|
||||||
|
Combined with unit tests, achieves **>95% code coverage**.
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Integration tests for MLflow prompts provider."""
|
||||||
133
tests/integration/providers/remote/prompts/mlflow/conftest.py
Normal file
133
tests/integration/providers/remote/prompts/mlflow/conftest.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Fixtures for MLflow integration tests.
|
||||||
|
|
||||||
|
These tests require a running MLflow server. Set the MLFLOW_TRACKING_URI
|
||||||
|
environment variable to point to your MLflow server, or the tests will
|
||||||
|
attempt to use http://localhost:5555.
|
||||||
|
|
||||||
|
To run tests:
|
||||||
|
# Start MLflow server (in separate terminal)
|
||||||
|
mlflow server --host 127.0.0.1 --port 5555
|
||||||
|
|
||||||
|
# Run integration tests
|
||||||
|
MLFLOW_TRACKING_URI=http://localhost:5555 \
|
||||||
|
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow import MLflowPromptsAdapter
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mlflow_tracking_uri():
|
||||||
|
"""Get MLflow tracking URI from environment or use default."""
|
||||||
|
return os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5555")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mlflow_server_available(mlflow_tracking_uri):
|
||||||
|
"""Verify MLflow server is running and accessible.
|
||||||
|
|
||||||
|
Skips all tests if server is not available.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.get(f"{mlflow_tracking_uri}/health", timeout=5)
|
||||||
|
if response.status_code != 200:
|
||||||
|
pytest.skip(f"MLflow server at {mlflow_tracking_uri} returned status {response.status_code}")
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("requests package not installed - install with: pip install requests")
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
pytest.skip(
|
||||||
|
f"MLflow server not available at {mlflow_tracking_uri}. "
|
||||||
|
"Start with: mlflow server --host 127.0.0.1 --port 5555"
|
||||||
|
)
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
pytest.skip(f"MLflow server at {mlflow_tracking_uri} timed out")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Failed to check MLflow server availability: {e}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mlflow_config(mlflow_tracking_uri, mlflow_server_available):
|
||||||
|
"""Create MLflow configuration for testing."""
|
||||||
|
return MLflowPromptsConfig(
|
||||||
|
mlflow_tracking_uri=mlflow_tracking_uri,
|
||||||
|
experiment_name="test-llama-stack-prompts",
|
||||||
|
timeout_seconds=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mlflow_adapter(mlflow_config):
|
||||||
|
"""Create and initialize MLflow adapter for testing.
|
||||||
|
|
||||||
|
This fixture creates a new adapter instance for each test.
|
||||||
|
The adapter connects to the MLflow server specified in the config.
|
||||||
|
"""
|
||||||
|
adapter = MLflowPromptsAdapter(config=mlflow_config)
|
||||||
|
await adapter.initialize()
|
||||||
|
|
||||||
|
yield adapter
|
||||||
|
|
||||||
|
# Cleanup: shutdown adapter
|
||||||
|
await adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mlflow_adapter_with_cleanup(mlflow_config):
|
||||||
|
"""Create MLflow adapter with automatic cleanup after test.
|
||||||
|
|
||||||
|
This fixture is useful for tests that create prompts and want them
|
||||||
|
automatically cleaned up (though MLflow doesn't support deletion,
|
||||||
|
so cleanup is best-effort).
|
||||||
|
"""
|
||||||
|
adapter = MLflowPromptsAdapter(config=mlflow_config)
|
||||||
|
await adapter.initialize()
|
||||||
|
|
||||||
|
created_prompt_ids = []
|
||||||
|
|
||||||
|
# Provide adapter and tracking list
|
||||||
|
class AdapterWithTracking:
|
||||||
|
def __init__(self, adapter_instance):
|
||||||
|
self.adapter = adapter_instance
|
||||||
|
self.created_ids = created_prompt_ids
|
||||||
|
|
||||||
|
async def create_prompt(self, *args, **kwargs):
|
||||||
|
prompt = await self.adapter.create_prompt(*args, **kwargs)
|
||||||
|
self.created_ids.append(prompt.prompt_id)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.adapter, name)
|
||||||
|
|
||||||
|
tracked_adapter = AdapterWithTracking(adapter)
|
||||||
|
|
||||||
|
yield tracked_adapter
|
||||||
|
|
||||||
|
# Cleanup: attempt to delete created prompts
|
||||||
|
# Note: MLflow doesn't support deletion, so this is a no-op
|
||||||
|
# but we keep it for future compatibility
|
||||||
|
for prompt_id in created_prompt_ids:
|
||||||
|
try:
|
||||||
|
await adapter.delete_prompt(prompt_id)
|
||||||
|
except NotImplementedError:
|
||||||
|
# Expected - MLflow doesn't support deletion
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
# Ignore cleanup errors
|
||||||
|
pass
|
||||||
|
|
||||||
|
await adapter.shutdown()
|
||||||
|
|
@ -0,0 +1,350 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""End-to-end integration tests for MLflow prompts provider.
|
||||||
|
|
||||||
|
These tests require a running MLflow server. See conftest.py for setup instructions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestMLflowPromptsEndToEnd:
|
||||||
|
"""End-to-end tests for MLflow prompts provider."""
|
||||||
|
|
||||||
|
async def test_create_and_retrieve_prompt(self, mlflow_adapter):
|
||||||
|
"""Test creating a prompt and retrieving it by ID."""
|
||||||
|
# Create prompt with variables
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Summarize the following text in {{ num_sentences }} sentences: {{ text }}",
|
||||||
|
variables=["num_sentences", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify created prompt
|
||||||
|
assert created.prompt_id.startswith("pmpt_")
|
||||||
|
assert len(created.prompt_id) == 53 # "pmpt_" + 48 hex chars
|
||||||
|
assert created.version == 1
|
||||||
|
assert created.is_default is True
|
||||||
|
assert set(created.variables) == {"num_sentences", "text"}
|
||||||
|
assert "{{ num_sentences }}" in created.prompt
|
||||||
|
assert "{{ text }}" in created.prompt
|
||||||
|
|
||||||
|
# Retrieve prompt by ID (should get default version)
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
|
||||||
|
assert retrieved.prompt_id == created.prompt_id
|
||||||
|
assert retrieved.prompt == created.prompt
|
||||||
|
assert retrieved.version == created.version
|
||||||
|
assert set(retrieved.variables) == set(created.variables)
|
||||||
|
assert retrieved.is_default is True
|
||||||
|
|
||||||
|
# Retrieve specific version
|
||||||
|
retrieved_v1 = await mlflow_adapter.get_prompt(created.prompt_id, version=1)
|
||||||
|
|
||||||
|
assert retrieved_v1.prompt_id == created.prompt_id
|
||||||
|
assert retrieved_v1.version == 1
|
||||||
|
|
||||||
|
async def test_update_prompt_creates_new_version(self, mlflow_adapter):
|
||||||
|
"""Test that updating a prompt creates a new version."""
|
||||||
|
# Create initial prompt (version 1)
|
||||||
|
v1 = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Original prompt with {{ variable }}",
|
||||||
|
variables=["variable"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert v1.version == 1
|
||||||
|
assert v1.is_default is True
|
||||||
|
|
||||||
|
# Update prompt (should create version 2)
|
||||||
|
v2 = await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=v1.prompt_id,
|
||||||
|
prompt="Updated prompt with {{ variable }}",
|
||||||
|
version=1,
|
||||||
|
variables=["variable"],
|
||||||
|
set_as_default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert v2.prompt_id == v1.prompt_id
|
||||||
|
assert v2.version == 2
|
||||||
|
assert v2.is_default is True
|
||||||
|
assert "Updated" in v2.prompt
|
||||||
|
|
||||||
|
# Verify both versions exist
|
||||||
|
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
|
||||||
|
versions = versions_response.data
|
||||||
|
|
||||||
|
assert len(versions) >= 2
|
||||||
|
assert any(v.version == 1 for v in versions)
|
||||||
|
assert any(v.version == 2 for v in versions)
|
||||||
|
|
||||||
|
# Verify version 1 still exists
|
||||||
|
v1_retrieved = await mlflow_adapter.get_prompt(v1.prompt_id, version=1)
|
||||||
|
assert "Original" in v1_retrieved.prompt
|
||||||
|
assert v1_retrieved.is_default is False # No longer default
|
||||||
|
|
||||||
|
# Verify version 2 is default
|
||||||
|
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||||
|
assert default.version == 2
|
||||||
|
assert "Updated" in default.prompt
|
||||||
|
|
||||||
|
async def test_list_prompts_returns_defaults_only(self, mlflow_adapter):
|
||||||
|
"""Test that list_prompts returns only default versions."""
|
||||||
|
# Create multiple prompts
|
||||||
|
p1 = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Prompt 1 with {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
p2 = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Prompt 2 with {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update first prompt (creates version 2)
|
||||||
|
await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=p1.prompt_id,
|
||||||
|
prompt="Prompt 1 updated with {{ var }}",
|
||||||
|
version=1,
|
||||||
|
variables=["var"],
|
||||||
|
set_as_default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# List all prompts
|
||||||
|
response = await mlflow_adapter.list_prompts()
|
||||||
|
prompts = response.data
|
||||||
|
|
||||||
|
# Should contain at least our 2 prompts
|
||||||
|
assert len(prompts) >= 2
|
||||||
|
|
||||||
|
# Find our prompts in the list
|
||||||
|
p1_in_list = next((p for p in prompts if p.prompt_id == p1.prompt_id), None)
|
||||||
|
p2_in_list = next((p for p in prompts if p.prompt_id == p2.prompt_id), None)
|
||||||
|
|
||||||
|
assert p1_in_list is not None
|
||||||
|
assert p2_in_list is not None
|
||||||
|
|
||||||
|
# p1 should be version 2 (updated version is default)
|
||||||
|
assert p1_in_list.version == 2
|
||||||
|
assert p1_in_list.is_default is True
|
||||||
|
|
||||||
|
# p2 should be version 1 (original is still default)
|
||||||
|
assert p2_in_list.version == 1
|
||||||
|
assert p2_in_list.is_default is True
|
||||||
|
|
||||||
|
async def test_list_prompt_versions(self, mlflow_adapter):
|
||||||
|
"""Test listing all versions of a specific prompt."""
|
||||||
|
# Create prompt
|
||||||
|
v1 = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Version 1 {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create multiple versions
|
||||||
|
_v2 = await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=v1.prompt_id,
|
||||||
|
prompt="Version 2 {{ var }}",
|
||||||
|
version=1,
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
_v3 = await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=v1.prompt_id,
|
||||||
|
prompt="Version 3 {{ var }}",
|
||||||
|
version=2,
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# List all versions
|
||||||
|
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
|
||||||
|
versions = versions_response.data
|
||||||
|
|
||||||
|
# Should have 3 versions
|
||||||
|
assert len(versions) == 3
|
||||||
|
|
||||||
|
# Verify versions are sorted by version number
|
||||||
|
assert versions[0].version == 1
|
||||||
|
assert versions[1].version == 2
|
||||||
|
assert versions[2].version == 3
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
assert "Version 1" in versions[0].prompt
|
||||||
|
assert "Version 2" in versions[1].prompt
|
||||||
|
assert "Version 3" in versions[2].prompt
|
||||||
|
|
||||||
|
# Only latest should be default
|
||||||
|
assert versions[0].is_default is False
|
||||||
|
assert versions[1].is_default is False
|
||||||
|
assert versions[2].is_default is True
|
||||||
|
|
||||||
|
async def test_set_default_version(self, mlflow_adapter):
|
||||||
|
"""Test changing which version is the default."""
|
||||||
|
# Create prompt and update it
|
||||||
|
v1 = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Version 1 {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
_v2 = await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=v1.prompt_id,
|
||||||
|
prompt="Version 2 {{ var }}",
|
||||||
|
version=1,
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# At this point, _v2 is default
|
||||||
|
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||||
|
assert default.version == 2
|
||||||
|
|
||||||
|
# Set v1 as default
|
||||||
|
updated = await mlflow_adapter.set_default_version(v1.prompt_id, 1)
|
||||||
|
assert updated.version == 1
|
||||||
|
assert updated.is_default is True
|
||||||
|
|
||||||
|
# Verify default changed
|
||||||
|
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||||
|
assert default.version == 1
|
||||||
|
assert "Version 1" in default.prompt
|
||||||
|
|
||||||
|
async def test_variable_auto_extraction(self, mlflow_adapter):
|
||||||
|
"""Test automatic variable extraction from template."""
|
||||||
|
# Create prompt without explicitly specifying variables
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Extract {{ entity }} from {{ text }} in {{ format }} format",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should auto-extract all variables
|
||||||
|
assert set(created.variables) == {"entity", "text", "format"}
|
||||||
|
|
||||||
|
# Retrieve and verify
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
assert set(retrieved.variables) == {"entity", "text", "format"}
|
||||||
|
|
||||||
|
async def test_variable_validation(self, mlflow_adapter):
|
||||||
|
"""Test that variable validation works correctly."""
|
||||||
|
# Should fail: template has undeclared variable
|
||||||
|
with pytest.raises(ValueError, match="undeclared variables"):
|
||||||
|
await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Template with {{ var1 }} and {{ var2 }}",
|
||||||
|
variables=["var1"], # Missing var2
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_prompt_not_found(self, mlflow_adapter):
|
||||||
|
"""Test error handling when prompt doesn't exist."""
|
||||||
|
fake_id = "pmpt_" + "0" * 48
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await mlflow_adapter.get_prompt(fake_id)
|
||||||
|
|
||||||
|
async def test_version_not_found(self, mlflow_adapter):
|
||||||
|
"""Test error handling when version doesn't exist."""
|
||||||
|
# Create prompt (version 1)
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Test {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get non-existent version
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await mlflow_adapter.get_prompt(created.prompt_id, version=999)
|
||||||
|
|
||||||
|
async def test_update_wrong_version(self, mlflow_adapter):
|
||||||
|
"""Test that updating with wrong version fails."""
|
||||||
|
# Create prompt (version 1)
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Test {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to update with wrong version number
|
||||||
|
with pytest.raises(ValueError, match="not the latest"):
|
||||||
|
await mlflow_adapter.update_prompt(
|
||||||
|
prompt_id=created.prompt_id,
|
||||||
|
prompt="Updated {{ var }}",
|
||||||
|
version=999, # Wrong version
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_delete_not_supported(self, mlflow_adapter):
|
||||||
|
"""Test that deletion raises NotImplementedError."""
|
||||||
|
# Create prompt
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="Test {{ var }}",
|
||||||
|
variables=["var"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to delete (should fail with NotImplementedError)
|
||||||
|
with pytest.raises(NotImplementedError, match="does not support deletion"):
|
||||||
|
await mlflow_adapter.delete_prompt(created.prompt_id)
|
||||||
|
|
||||||
|
# Verify prompt still exists
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
assert retrieved.prompt_id == created.prompt_id
|
||||||
|
|
||||||
|
async def test_complex_template_with_multiple_variables(self, mlflow_adapter):
|
||||||
|
"""Test prompt with complex template and multiple variables."""
|
||||||
|
template = """You are a {{ role }} assistant specialized in {{ domain }}.
|
||||||
|
|
||||||
|
Task: {{ task }}
|
||||||
|
|
||||||
|
Context:
|
||||||
|
{{ context }}
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
1. {{ instruction1 }}
|
||||||
|
2. {{ instruction2 }}
|
||||||
|
3. {{ instruction3 }}
|
||||||
|
|
||||||
|
Output format: {{ output_format }}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create with auto-extraction
|
||||||
|
created = await mlflow_adapter.create_prompt(prompt=template)
|
||||||
|
|
||||||
|
# Should extract all variables
|
||||||
|
expected_vars = {
|
||||||
|
"role",
|
||||||
|
"domain",
|
||||||
|
"task",
|
||||||
|
"context",
|
||||||
|
"instruction1",
|
||||||
|
"instruction2",
|
||||||
|
"instruction3",
|
||||||
|
"output_format",
|
||||||
|
}
|
||||||
|
assert set(created.variables) == expected_vars
|
||||||
|
|
||||||
|
# Retrieve and verify template preserved
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
assert retrieved.prompt == template
|
||||||
|
|
||||||
|
async def test_empty_template(self, mlflow_adapter):
|
||||||
|
"""Test handling of empty template."""
|
||||||
|
# Create prompt with empty template
|
||||||
|
created = await mlflow_adapter.create_prompt(
|
||||||
|
prompt="",
|
||||||
|
variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert created.prompt == ""
|
||||||
|
assert created.variables == []
|
||||||
|
|
||||||
|
# Retrieve and verify
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
assert retrieved.prompt == ""
|
||||||
|
|
||||||
|
async def test_template_with_no_variables(self, mlflow_adapter):
|
||||||
|
"""Test template without any variables."""
|
||||||
|
template = "This is a static prompt with no variables."
|
||||||
|
|
||||||
|
created = await mlflow_adapter.create_prompt(prompt=template)
|
||||||
|
|
||||||
|
assert created.prompt == template
|
||||||
|
assert created.variables == []
|
||||||
|
|
||||||
|
# Retrieve and verify
|
||||||
|
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||||
|
assert retrieved.prompt == template
|
||||||
|
assert retrieved.variables == []
|
||||||
7
tests/unit/providers/remote/prompts/__init__.py
Normal file
7
tests/unit/providers/remote/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for remote prompts providers."""
|
||||||
7
tests/unit/providers/remote/prompts/mlflow/__init__.py
Normal file
7
tests/unit/providers/remote/prompts/mlflow/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for MLflow prompts provider."""
|
||||||
138
tests/unit/providers/remote/prompts/mlflow/test_config.py
Normal file
138
tests/unit/providers/remote/prompts/mlflow/test_config.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for MLflow prompts provider configuration."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr, ValidationError
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow.config import (
|
||||||
|
MLflowPromptsConfig,
|
||||||
|
MLflowProviderDataValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMLflowPromptsConfig:
|
||||||
|
"""Tests for MLflowPromptsConfig model."""
|
||||||
|
|
||||||
|
def test_default_config(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = MLflowPromptsConfig()
|
||||||
|
|
||||||
|
assert config.mlflow_tracking_uri == "http://localhost:5000"
|
||||||
|
assert config.mlflow_registry_uri is None
|
||||||
|
assert config.experiment_name == "llama-stack-prompts"
|
||||||
|
assert config.auth_credential is None
|
||||||
|
assert config.timeout_seconds == 30
|
||||||
|
|
||||||
|
def test_custom_config(self):
|
||||||
|
"""Test custom configuration values."""
|
||||||
|
config = MLflowPromptsConfig(
|
||||||
|
mlflow_tracking_uri="http://mlflow.example.com:8080",
|
||||||
|
mlflow_registry_uri="http://registry.example.com:8080",
|
||||||
|
experiment_name="my-prompts",
|
||||||
|
auth_credential=SecretStr("my-token"),
|
||||||
|
timeout_seconds=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.mlflow_tracking_uri == "http://mlflow.example.com:8080"
|
||||||
|
assert config.mlflow_registry_uri == "http://registry.example.com:8080"
|
||||||
|
assert config.experiment_name == "my-prompts"
|
||||||
|
assert config.auth_credential.get_secret_value() == "my-token"
|
||||||
|
assert config.timeout_seconds == 60
|
||||||
|
|
||||||
|
def test_databricks_uri(self):
|
||||||
|
"""Test Databricks URI configuration."""
|
||||||
|
config = MLflowPromptsConfig(
|
||||||
|
mlflow_tracking_uri="databricks",
|
||||||
|
mlflow_registry_uri="databricks://profile",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.mlflow_tracking_uri == "databricks"
|
||||||
|
assert config.mlflow_registry_uri == "databricks://profile"
|
||||||
|
|
||||||
|
def test_tracking_uri_validation(self):
|
||||||
|
"""Test tracking URI validation."""
|
||||||
|
# Empty string rejected
|
||||||
|
with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
|
||||||
|
MLflowPromptsConfig(mlflow_tracking_uri="")
|
||||||
|
|
||||||
|
# Whitespace-only rejected
|
||||||
|
with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
|
||||||
|
MLflowPromptsConfig(mlflow_tracking_uri=" ")
|
||||||
|
|
||||||
|
# Whitespace is stripped
|
||||||
|
config = MLflowPromptsConfig(mlflow_tracking_uri=" http://localhost:5000 ")
|
||||||
|
assert config.mlflow_tracking_uri == "http://localhost:5000"
|
||||||
|
|
||||||
|
def test_experiment_name_validation(self):
|
||||||
|
"""Test experiment name validation."""
|
||||||
|
# Empty string rejected
|
||||||
|
with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
|
||||||
|
MLflowPromptsConfig(experiment_name="")
|
||||||
|
|
||||||
|
# Whitespace-only rejected
|
||||||
|
with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
|
||||||
|
MLflowPromptsConfig(experiment_name=" ")
|
||||||
|
|
||||||
|
# Whitespace is stripped
|
||||||
|
config = MLflowPromptsConfig(experiment_name=" my-experiment ")
|
||||||
|
assert config.experiment_name == "my-experiment"
|
||||||
|
|
||||||
|
def test_timeout_validation(self):
|
||||||
|
"""Test timeout range validation."""
|
||||||
|
# Too low rejected
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MLflowPromptsConfig(timeout_seconds=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MLflowPromptsConfig(timeout_seconds=-1)
|
||||||
|
|
||||||
|
# Too high rejected
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MLflowPromptsConfig(timeout_seconds=301)
|
||||||
|
|
||||||
|
# Boundary values accepted
|
||||||
|
config_min = MLflowPromptsConfig(timeout_seconds=1)
|
||||||
|
assert config_min.timeout_seconds == 1
|
||||||
|
|
||||||
|
config_max = MLflowPromptsConfig(timeout_seconds=300)
|
||||||
|
assert config_max.timeout_seconds == 300
|
||||||
|
|
||||||
|
def test_sample_run_config(self):
|
||||||
|
"""Test sample_run_config generates valid configuration."""
|
||||||
|
# Default environment variable
|
||||||
|
sample = MLflowPromptsConfig.sample_run_config()
|
||||||
|
assert sample["mlflow_tracking_uri"] == "http://localhost:5000"
|
||||||
|
assert sample["experiment_name"] == "llama-stack-prompts"
|
||||||
|
assert sample["auth_credential"] == "${env.MLFLOW_TRACKING_TOKEN:=}"
|
||||||
|
|
||||||
|
# Custom values
|
||||||
|
sample = MLflowPromptsConfig.sample_run_config(
|
||||||
|
mlflow_api_token="test-token",
|
||||||
|
mlflow_tracking_uri="http://custom:5000",
|
||||||
|
)
|
||||||
|
assert sample["mlflow_tracking_uri"] == "http://custom:5000"
|
||||||
|
assert sample["auth_credential"] == "test-token"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMLflowProviderDataValidator:
|
||||||
|
"""Tests for MLflowProviderDataValidator."""
|
||||||
|
|
||||||
|
def test_provider_data_validator(self):
|
||||||
|
"""Test provider data validator with and without token."""
|
||||||
|
# With token
|
||||||
|
validator = MLflowProviderDataValidator(mlflow_api_token="test-token-123")
|
||||||
|
assert validator.mlflow_api_token == "test-token-123"
|
||||||
|
|
||||||
|
# Without token
|
||||||
|
validator = MLflowProviderDataValidator()
|
||||||
|
assert validator.mlflow_api_token is None
|
||||||
|
|
||||||
|
# From dictionary
|
||||||
|
data = {"mlflow_api_token": "secret-token"}
|
||||||
|
validator = MLflowProviderDataValidator(**data)
|
||||||
|
assert validator.mlflow_api_token == "secret-token"
|
||||||
95
tests/unit/providers/remote/prompts/mlflow/test_mapping.py
Normal file
95
tests/unit/providers/remote/prompts/mlflow/test_mapping.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for MLflow prompts ID mapping utilities."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptIDMapper:
|
||||||
|
"""Tests for PromptIDMapper class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mapper(self):
|
||||||
|
"""Create ID mapper instance."""
|
||||||
|
return PromptIDMapper()
|
||||||
|
|
||||||
|
def test_to_mlflow_name_valid(self, mapper):
|
||||||
|
"""Test converting valid prompt_id to MLflow name."""
|
||||||
|
prompt_id = "pmpt_" + "a" * 48
|
||||||
|
mlflow_name = mapper.to_mlflow_name(prompt_id)
|
||||||
|
|
||||||
|
assert mlflow_name == "llama_prompt_" + "a" * 48
|
||||||
|
assert mlflow_name.startswith(mapper.MLFLOW_NAME_PREFIX)
|
||||||
|
|
||||||
|
def test_to_mlflow_name_invalid(self, mapper):
|
||||||
|
"""Test conversion fails with invalid inputs."""
|
||||||
|
# Invalid prefix
|
||||||
|
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||||
|
mapper.to_mlflow_name("invalid_" + "a" * 48)
|
||||||
|
|
||||||
|
# Wrong length
|
||||||
|
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||||
|
mapper.to_mlflow_name("pmpt_" + "a" * 47)
|
||||||
|
|
||||||
|
# Invalid hex characters
|
||||||
|
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||||
|
mapper.to_mlflow_name("pmpt_" + "g" * 48)
|
||||||
|
|
||||||
|
def test_to_llama_id_valid(self, mapper):
|
||||||
|
"""Test converting valid MLflow name to prompt_id."""
|
||||||
|
mlflow_name = "llama_prompt_" + "b" * 48
|
||||||
|
prompt_id = mapper.to_llama_id(mlflow_name)
|
||||||
|
|
||||||
|
assert prompt_id == "pmpt_" + "b" * 48
|
||||||
|
assert prompt_id.startswith("pmpt_")
|
||||||
|
|
||||||
|
def test_to_llama_id_invalid(self, mapper):
|
||||||
|
"""Test conversion fails with invalid inputs."""
|
||||||
|
# Invalid prefix
|
||||||
|
with pytest.raises(ValueError, match="does not start with expected prefix"):
|
||||||
|
mapper.to_llama_id("wrong_prefix_" + "a" * 48)
|
||||||
|
|
||||||
|
# Wrong length
|
||||||
|
with pytest.raises(ValueError, match="Invalid hex part length"):
|
||||||
|
mapper.to_llama_id("llama_prompt_" + "a" * 47)
|
||||||
|
|
||||||
|
# Invalid hex characters
|
||||||
|
with pytest.raises(ValueError, match="Invalid character"):
|
||||||
|
mapper.to_llama_id("llama_prompt_" + "G" * 48)
|
||||||
|
|
||||||
|
def test_bidirectional_conversion(self, mapper):
|
||||||
|
"""Test bidirectional conversion preserves IDs."""
|
||||||
|
original_id = "pmpt_0123456789abcdef" + "a" * 32
|
||||||
|
|
||||||
|
# Convert to MLflow name and back
|
||||||
|
mlflow_name = mapper.to_mlflow_name(original_id)
|
||||||
|
recovered_id = mapper.to_llama_id(mlflow_name)
|
||||||
|
|
||||||
|
assert recovered_id == original_id
|
||||||
|
|
||||||
|
def test_get_metadata_tags_with_variables(self, mapper):
|
||||||
|
"""Test metadata tags generation with variables."""
|
||||||
|
prompt_id = "pmpt_" + "c" * 48
|
||||||
|
variables = ["var1", "var2", "var3"]
|
||||||
|
|
||||||
|
tags = mapper.get_metadata_tags(prompt_id, variables)
|
||||||
|
|
||||||
|
assert tags["llama_stack_id"] == prompt_id
|
||||||
|
assert tags["llama_stack_managed"] == "true"
|
||||||
|
assert tags["variables"] == "var1,var2,var3"
|
||||||
|
|
||||||
|
def test_get_metadata_tags_without_variables(self, mapper):
|
||||||
|
"""Test metadata tags generation without variables."""
|
||||||
|
prompt_id = "pmpt_" + "d" * 48
|
||||||
|
|
||||||
|
tags = mapper.get_metadata_tags(prompt_id)
|
||||||
|
|
||||||
|
assert tags["llama_stack_id"] == prompt_id
|
||||||
|
assert tags["llama_stack_managed"] == "true"
|
||||||
|
assert "variables" not in tags
|
||||||
Loading…
Add table
Add a link
Reference in a new issue