mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge 5c4da04f29 into 4237eb4aaa
This commit is contained in:
commit
d5836c3b5a
24 changed files with 3594 additions and 225 deletions
7
tests/integration/providers/remote/prompts/__init__.py
Normal file
7
tests/integration/providers/remote/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Integration tests for remote prompts providers."""
|
||||
274
tests/integration/providers/remote/prompts/mlflow/README.md
Normal file
274
tests/integration/providers/remote/prompts/mlflow/README.md
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
# MLflow Prompts Provider - Integration Tests
|
||||
|
||||
This directory contains integration tests for the MLflow Prompts Provider. These tests require a running MLflow server.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **MLflow installed**: `pip install 'mlflow>=3.4.0'` (or `uv pip install 'mlflow>=3.4.0'`)
|
||||
2. **MLflow server running**: See setup instructions below
|
||||
3. **Test dependencies**: `uv sync --group test`
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start MLflow Server
|
||||
|
||||
```bash
|
||||
# Start MLflow server on localhost:5555
|
||||
mlflow server --host 127.0.0.1 --port 5555
|
||||
|
||||
# Keep this terminal open - server will continue running
|
||||
```
|
||||
|
||||
### 2. Run Integration Tests
|
||||
|
||||
In a separate terminal:
|
||||
|
||||
```bash
|
||||
# Set MLflow URI (optional - defaults to localhost:5555)
|
||||
export MLFLOW_TRACKING_URI=http://localhost:5555
|
||||
|
||||
# Run all integration tests
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||
|
||||
# Run specific test
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/test_end_to_end.py::TestMLflowPromptsEndToEnd::test_create_and_retrieve_prompt
|
||||
```
|
||||
|
||||
### 3. Run Manual Test Script (Optional)
|
||||
|
||||
For quick validation without pytest:
|
||||
|
||||
```bash
|
||||
# Run manual test script
|
||||
uv run python scripts/test_mlflow_prompts_manual.py
|
||||
|
||||
# View output in MLflow UI
|
||||
open http://localhost:5555
|
||||
```
|
||||
|
||||
## Test Organization
|
||||
|
||||
### Integration Tests (`test_end_to_end.py`)
|
||||
|
||||
Comprehensive end-to-end tests covering:
|
||||
|
||||
- ✅ Create and retrieve prompts
|
||||
- ✅ Update prompts (version management)
|
||||
- ✅ List prompts (default versions only)
|
||||
- ✅ List all versions of a prompt
|
||||
- ✅ Set default version
|
||||
- ✅ Variable auto-extraction
|
||||
- ✅ Variable validation
|
||||
- ✅ Error handling (not found, wrong version, etc.)
|
||||
- ✅ Complex templates with multiple variables
|
||||
- ✅ Edge cases (empty templates, no variables, etc.)
|
||||
|
||||
**Total**: 17 test scenarios
|
||||
|
||||
### Manual Test Script (`scripts/test_mlflow_prompts_manual.py`)
|
||||
|
||||
Interactive test script with verbose output for:
|
||||
|
||||
- Server connectivity check
|
||||
- Provider initialization
|
||||
- Basic CRUD operations
|
||||
- Variable extraction
|
||||
- Statistics retrieval
|
||||
|
||||
## Configuration
|
||||
|
||||
### MLflow Server Options
|
||||
|
||||
**Local (default)**:
|
||||
```bash
|
||||
mlflow server --host 127.0.0.1 --port 5555
|
||||
```
|
||||
|
||||
**Remote server**:
|
||||
```bash
|
||||
export MLFLOW_TRACKING_URI=http://mlflow.example.com:5000
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||
```
|
||||
|
||||
**Databricks**:
|
||||
```bash
|
||||
export MLFLOW_TRACKING_URI=databricks
|
||||
export MLFLOW_REGISTRY_URI=databricks://profile
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||
```
|
||||
|
||||
### Test Timeout
|
||||
|
||||
Tests have a default timeout of 30 seconds per MLflow operation. Adjust in `conftest.py`:
|
||||
|
||||
```python
|
||||
MLflowPromptsConfig(
|
||||
mlflow_tracking_uri=mlflow_tracking_uri,
|
||||
timeout_seconds=60, # Increase for slow connections
|
||||
)
|
||||
```
|
||||
|
||||
## Fixtures
|
||||
|
||||
### `mlflow_adapter`
|
||||
|
||||
Basic adapter for simple tests:
|
||||
|
||||
```python
|
||||
async def test_something(mlflow_adapter):
|
||||
prompt = await mlflow_adapter.create_prompt(...)
|
||||
# Test continues...
|
||||
```
|
||||
|
||||
### `mlflow_adapter_with_cleanup`
|
||||
|
||||
Adapter with automatic cleanup tracking:
|
||||
|
||||
```python
|
||||
async def test_something(mlflow_adapter_with_cleanup):
|
||||
# Creates are tracked and attempted cleanup on teardown
|
||||
prompt = await mlflow_adapter_with_cleanup.create_prompt(...)
|
||||
```
|
||||
|
||||
**Note**: MLflow doesn't support deletion, so cleanup is best-effort.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Server Not Available
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
SKIPPED [1] conftest.py:35: MLflow server not available at http://localhost:5555
|
||||
```
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Start MLflow server
|
||||
mlflow server --host 127.0.0.1 --port 5555
|
||||
|
||||
# Verify it's running
|
||||
curl http://localhost:5555/health
|
||||
```
|
||||
|
||||
### Connection Timeout
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
requests.exceptions.Timeout: ...
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
1. Check MLflow server is responsive: `curl http://localhost:5555/health`
|
||||
2. Increase timeout in `conftest.py`: `timeout_seconds=60`
|
||||
3. Check firewall/network settings
|
||||
|
||||
### Import Errors
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
ModuleNotFoundError: No module named 'mlflow'
|
||||
```
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
uv pip install 'mlflow>=3.4.0'
|
||||
```
|
||||
|
||||
### Permission Errors
|
||||
|
||||
**Symptom**:
|
||||
```
|
||||
PermissionError: [Errno 13] Permission denied: '...'
|
||||
```
|
||||
|
||||
**Solution**:
|
||||
- Ensure MLflow has write access to its storage directory
|
||||
- Check file permissions on `mlruns/` directory
|
||||
|
||||
### Test Isolation Issues
|
||||
|
||||
**Issue**: Tests may interfere with each other if using same prompt IDs
|
||||
|
||||
**Solution**: Each test creates new prompts with unique IDs (generated by `Prompt.generate_prompt_id()`). If needed, use `mlflow_adapter_with_cleanup` fixture.
|
||||
|
||||
## Viewing Results
|
||||
|
||||
### MLflow UI
|
||||
|
||||
1. Start MLflow server (if not already running):
|
||||
```bash
|
||||
mlflow server --host 127.0.0.1 --port 5555
|
||||
```
|
||||
|
||||
2. Open in browser:
|
||||
```
|
||||
http://localhost:5555
|
||||
```
|
||||
|
||||
3. Navigate to experiment `test-llama-stack-prompts`
|
||||
|
||||
4. View registered prompts and their versions
|
||||
|
||||
### Test Output
|
||||
|
||||
Run with verbose output to see detailed test execution:
|
||||
|
||||
```bash
|
||||
uv run --group test pytest -vv tests/integration/providers/remote/prompts/mlflow/
|
||||
```
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
To run tests in CI/CD pipelines:
|
||||
|
||||
```yaml
|
||||
# Example GitHub Actions workflow
|
||||
- name: Start MLflow server
|
||||
run: |
|
||||
mlflow server --host 127.0.0.1 --port 5555 &
|
||||
sleep 5 # Wait for server to start
|
||||
|
||||
- name: Wait for MLflow
|
||||
run: |
|
||||
timeout 30 bash -c 'until curl -s http://localhost:5555/health; do sleep 1; done'
|
||||
|
||||
- name: Run integration tests
|
||||
env:
|
||||
MLFLOW_TRACKING_URI: http://localhost:5555
|
||||
run: |
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
### Expected Test Duration
|
||||
|
||||
- **Individual test**: ~1-5 seconds
|
||||
- **Full suite** (17 tests): ~30-60 seconds
|
||||
- **Manual script**: ~10-15 seconds
|
||||
|
||||
### Optimization Tips
|
||||
|
||||
1. Use local MLflow server (faster than remote)
|
||||
2. Run tests in parallel (if safe):
|
||||
```bash
|
||||
uv run --group test pytest -n auto tests/integration/providers/remote/prompts/mlflow/
|
||||
```
|
||||
3. Skip integration tests in development:
|
||||
```bash
|
||||
uv run --group dev pytest -sv tests/unit/
|
||||
```
|
||||
|
||||
## Coverage
|
||||
|
||||
Integration tests provide coverage for:
|
||||
|
||||
- ✅ Real MLflow API calls
|
||||
- ✅ Network communication
|
||||
- ✅ Serialization/deserialization
|
||||
- ✅ MLflow server responses
|
||||
- ✅ Version management
|
||||
- ✅ Alias handling
|
||||
- ✅ Tag storage and retrieval
|
||||
|
||||
Combined with unit tests, achieves **>95% code coverage**.
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Integration tests for MLflow prompts provider."""
|
||||
133
tests/integration/providers/remote/prompts/mlflow/conftest.py
Normal file
133
tests/integration/providers/remote/prompts/mlflow/conftest.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Fixtures for MLflow integration tests.
|
||||
|
||||
These tests require a running MLflow server. Set the MLFLOW_TRACKING_URI
|
||||
environment variable to point to your MLflow server, or the tests will
|
||||
attempt to use http://localhost:5555.
|
||||
|
||||
To run tests:
|
||||
# Start MLflow server (in separate terminal)
|
||||
mlflow server --host 127.0.0.1 --port 5555
|
||||
|
||||
# Run integration tests
|
||||
MLFLOW_TRACKING_URI=http://localhost:5555 \
|
||||
uv run --group test pytest -sv tests/integration/providers/remote/prompts/mlflow/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.remote.prompts.mlflow import MLflowPromptsAdapter
|
||||
from llama_stack.providers.remote.prompts.mlflow.config import MLflowPromptsConfig
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mlflow_tracking_uri():
|
||||
"""Get MLflow tracking URI from environment or use default."""
|
||||
return os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5555")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mlflow_server_available(mlflow_tracking_uri):
|
||||
"""Verify MLflow server is running and accessible.
|
||||
|
||||
Skips all tests if server is not available.
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.get(f"{mlflow_tracking_uri}/health", timeout=5)
|
||||
if response.status_code != 200:
|
||||
pytest.skip(f"MLflow server at {mlflow_tracking_uri} returned status {response.status_code}")
|
||||
except ImportError:
|
||||
pytest.skip("requests package not installed - install with: pip install requests")
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip(
|
||||
f"MLflow server not available at {mlflow_tracking_uri}. "
|
||||
"Start with: mlflow server --host 127.0.0.1 --port 5555"
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
pytest.skip(f"MLflow server at {mlflow_tracking_uri} timed out")
|
||||
except Exception as e:
|
||||
pytest.skip(f"Failed to check MLflow server availability: {e}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mlflow_config(mlflow_tracking_uri, mlflow_server_available):
|
||||
"""Create MLflow configuration for testing."""
|
||||
return MLflowPromptsConfig(
|
||||
mlflow_tracking_uri=mlflow_tracking_uri,
|
||||
experiment_name="test-llama-stack-prompts",
|
||||
timeout_seconds=30,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mlflow_adapter(mlflow_config):
|
||||
"""Create and initialize MLflow adapter for testing.
|
||||
|
||||
This fixture creates a new adapter instance for each test.
|
||||
The adapter connects to the MLflow server specified in the config.
|
||||
"""
|
||||
adapter = MLflowPromptsAdapter(config=mlflow_config)
|
||||
await adapter.initialize()
|
||||
|
||||
yield adapter
|
||||
|
||||
# Cleanup: shutdown adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mlflow_adapter_with_cleanup(mlflow_config):
|
||||
"""Create MLflow adapter with automatic cleanup after test.
|
||||
|
||||
This fixture is useful for tests that create prompts and want them
|
||||
automatically cleaned up (though MLflow doesn't support deletion,
|
||||
so cleanup is best-effort).
|
||||
"""
|
||||
adapter = MLflowPromptsAdapter(config=mlflow_config)
|
||||
await adapter.initialize()
|
||||
|
||||
created_prompt_ids = []
|
||||
|
||||
# Provide adapter and tracking list
|
||||
class AdapterWithTracking:
|
||||
def __init__(self, adapter_instance):
|
||||
self.adapter = adapter_instance
|
||||
self.created_ids = created_prompt_ids
|
||||
|
||||
async def create_prompt(self, *args, **kwargs):
|
||||
prompt = await self.adapter.create_prompt(*args, **kwargs)
|
||||
self.created_ids.append(prompt.prompt_id)
|
||||
return prompt
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.adapter, name)
|
||||
|
||||
tracked_adapter = AdapterWithTracking(adapter)
|
||||
|
||||
yield tracked_adapter
|
||||
|
||||
# Cleanup: attempt to delete created prompts
|
||||
# Note: MLflow doesn't support deletion, so this is a no-op
|
||||
# but we keep it for future compatibility
|
||||
for prompt_id in created_prompt_ids:
|
||||
try:
|
||||
await adapter.delete_prompt(prompt_id)
|
||||
except NotImplementedError:
|
||||
# Expected - MLflow doesn't support deletion
|
||||
pass
|
||||
except Exception:
|
||||
# Ignore cleanup errors
|
||||
pass
|
||||
|
||||
await adapter.shutdown()
|
||||
|
|
@ -0,0 +1,350 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""End-to-end integration tests for MLflow prompts provider.
|
||||
|
||||
These tests require a running MLflow server. See conftest.py for setup instructions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMLflowPromptsEndToEnd:
|
||||
"""End-to-end tests for MLflow prompts provider."""
|
||||
|
||||
async def test_create_and_retrieve_prompt(self, mlflow_adapter):
|
||||
"""Test creating a prompt and retrieving it by ID."""
|
||||
# Create prompt with variables
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="Summarize the following text in {{ num_sentences }} sentences: {{ text }}",
|
||||
variables=["num_sentences", "text"],
|
||||
)
|
||||
|
||||
# Verify created prompt
|
||||
assert created.prompt_id.startswith("pmpt_")
|
||||
assert len(created.prompt_id) == 53 # "pmpt_" + 48 hex chars
|
||||
assert created.version == 1
|
||||
assert created.is_default is True
|
||||
assert set(created.variables) == {"num_sentences", "text"}
|
||||
assert "{{ num_sentences }}" in created.prompt
|
||||
assert "{{ text }}" in created.prompt
|
||||
|
||||
# Retrieve prompt by ID (should get default version)
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
|
||||
assert retrieved.prompt_id == created.prompt_id
|
||||
assert retrieved.prompt == created.prompt
|
||||
assert retrieved.version == created.version
|
||||
assert set(retrieved.variables) == set(created.variables)
|
||||
assert retrieved.is_default is True
|
||||
|
||||
# Retrieve specific version
|
||||
retrieved_v1 = await mlflow_adapter.get_prompt(created.prompt_id, version=1)
|
||||
|
||||
assert retrieved_v1.prompt_id == created.prompt_id
|
||||
assert retrieved_v1.version == 1
|
||||
|
||||
async def test_update_prompt_creates_new_version(self, mlflow_adapter):
|
||||
"""Test that updating a prompt creates a new version."""
|
||||
# Create initial prompt (version 1)
|
||||
v1 = await mlflow_adapter.create_prompt(
|
||||
prompt="Original prompt with {{ variable }}",
|
||||
variables=["variable"],
|
||||
)
|
||||
|
||||
assert v1.version == 1
|
||||
assert v1.is_default is True
|
||||
|
||||
# Update prompt (should create version 2)
|
||||
v2 = await mlflow_adapter.update_prompt(
|
||||
prompt_id=v1.prompt_id,
|
||||
prompt="Updated prompt with {{ variable }}",
|
||||
version=1,
|
||||
variables=["variable"],
|
||||
set_as_default=True,
|
||||
)
|
||||
|
||||
assert v2.prompt_id == v1.prompt_id
|
||||
assert v2.version == 2
|
||||
assert v2.is_default is True
|
||||
assert "Updated" in v2.prompt
|
||||
|
||||
# Verify both versions exist
|
||||
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
|
||||
versions = versions_response.data
|
||||
|
||||
assert len(versions) >= 2
|
||||
assert any(v.version == 1 for v in versions)
|
||||
assert any(v.version == 2 for v in versions)
|
||||
|
||||
# Verify version 1 still exists
|
||||
v1_retrieved = await mlflow_adapter.get_prompt(v1.prompt_id, version=1)
|
||||
assert "Original" in v1_retrieved.prompt
|
||||
assert v1_retrieved.is_default is False # No longer default
|
||||
|
||||
# Verify version 2 is default
|
||||
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||
assert default.version == 2
|
||||
assert "Updated" in default.prompt
|
||||
|
||||
async def test_list_prompts_returns_defaults_only(self, mlflow_adapter):
|
||||
"""Test that list_prompts returns only default versions."""
|
||||
# Create multiple prompts
|
||||
p1 = await mlflow_adapter.create_prompt(
|
||||
prompt="Prompt 1 with {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
p2 = await mlflow_adapter.create_prompt(
|
||||
prompt="Prompt 2 with {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# Update first prompt (creates version 2)
|
||||
await mlflow_adapter.update_prompt(
|
||||
prompt_id=p1.prompt_id,
|
||||
prompt="Prompt 1 updated with {{ var }}",
|
||||
version=1,
|
||||
variables=["var"],
|
||||
set_as_default=True,
|
||||
)
|
||||
|
||||
# List all prompts
|
||||
response = await mlflow_adapter.list_prompts()
|
||||
prompts = response.data
|
||||
|
||||
# Should contain at least our 2 prompts
|
||||
assert len(prompts) >= 2
|
||||
|
||||
# Find our prompts in the list
|
||||
p1_in_list = next((p for p in prompts if p.prompt_id == p1.prompt_id), None)
|
||||
p2_in_list = next((p for p in prompts if p.prompt_id == p2.prompt_id), None)
|
||||
|
||||
assert p1_in_list is not None
|
||||
assert p2_in_list is not None
|
||||
|
||||
# p1 should be version 2 (updated version is default)
|
||||
assert p1_in_list.version == 2
|
||||
assert p1_in_list.is_default is True
|
||||
|
||||
# p2 should be version 1 (original is still default)
|
||||
assert p2_in_list.version == 1
|
||||
assert p2_in_list.is_default is True
|
||||
|
||||
async def test_list_prompt_versions(self, mlflow_adapter):
|
||||
"""Test listing all versions of a specific prompt."""
|
||||
# Create prompt
|
||||
v1 = await mlflow_adapter.create_prompt(
|
||||
prompt="Version 1 {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# Create multiple versions
|
||||
_v2 = await mlflow_adapter.update_prompt(
|
||||
prompt_id=v1.prompt_id,
|
||||
prompt="Version 2 {{ var }}",
|
||||
version=1,
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
_v3 = await mlflow_adapter.update_prompt(
|
||||
prompt_id=v1.prompt_id,
|
||||
prompt="Version 3 {{ var }}",
|
||||
version=2,
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# List all versions
|
||||
versions_response = await mlflow_adapter.list_prompt_versions(v1.prompt_id)
|
||||
versions = versions_response.data
|
||||
|
||||
# Should have 3 versions
|
||||
assert len(versions) == 3
|
||||
|
||||
# Verify versions are sorted by version number
|
||||
assert versions[0].version == 1
|
||||
assert versions[1].version == 2
|
||||
assert versions[2].version == 3
|
||||
|
||||
# Verify content
|
||||
assert "Version 1" in versions[0].prompt
|
||||
assert "Version 2" in versions[1].prompt
|
||||
assert "Version 3" in versions[2].prompt
|
||||
|
||||
# Only latest should be default
|
||||
assert versions[0].is_default is False
|
||||
assert versions[1].is_default is False
|
||||
assert versions[2].is_default is True
|
||||
|
||||
async def test_set_default_version(self, mlflow_adapter):
|
||||
"""Test changing which version is the default."""
|
||||
# Create prompt and update it
|
||||
v1 = await mlflow_adapter.create_prompt(
|
||||
prompt="Version 1 {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
_v2 = await mlflow_adapter.update_prompt(
|
||||
prompt_id=v1.prompt_id,
|
||||
prompt="Version 2 {{ var }}",
|
||||
version=1,
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# At this point, _v2 is default
|
||||
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||
assert default.version == 2
|
||||
|
||||
# Set v1 as default
|
||||
updated = await mlflow_adapter.set_default_version(v1.prompt_id, 1)
|
||||
assert updated.version == 1
|
||||
assert updated.is_default is True
|
||||
|
||||
# Verify default changed
|
||||
default = await mlflow_adapter.get_prompt(v1.prompt_id)
|
||||
assert default.version == 1
|
||||
assert "Version 1" in default.prompt
|
||||
|
||||
async def test_variable_auto_extraction(self, mlflow_adapter):
|
||||
"""Test automatic variable extraction from template."""
|
||||
# Create prompt without explicitly specifying variables
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="Extract {{ entity }} from {{ text }} in {{ format }} format",
|
||||
)
|
||||
|
||||
# Should auto-extract all variables
|
||||
assert set(created.variables) == {"entity", "text", "format"}
|
||||
|
||||
# Retrieve and verify
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
assert set(retrieved.variables) == {"entity", "text", "format"}
|
||||
|
||||
async def test_variable_validation(self, mlflow_adapter):
|
||||
"""Test that variable validation works correctly."""
|
||||
# Should fail: template has undeclared variable
|
||||
with pytest.raises(ValueError, match="undeclared variables"):
|
||||
await mlflow_adapter.create_prompt(
|
||||
prompt="Template with {{ var1 }} and {{ var2 }}",
|
||||
variables=["var1"], # Missing var2
|
||||
)
|
||||
|
||||
async def test_prompt_not_found(self, mlflow_adapter):
|
||||
"""Test error handling when prompt doesn't exist."""
|
||||
fake_id = "pmpt_" + "0" * 48
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await mlflow_adapter.get_prompt(fake_id)
|
||||
|
||||
async def test_version_not_found(self, mlflow_adapter):
|
||||
"""Test error handling when version doesn't exist."""
|
||||
# Create prompt (version 1)
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="Test {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# Try to get non-existent version
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await mlflow_adapter.get_prompt(created.prompt_id, version=999)
|
||||
|
||||
async def test_update_wrong_version(self, mlflow_adapter):
|
||||
"""Test that updating with wrong version fails."""
|
||||
# Create prompt (version 1)
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="Test {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# Try to update with wrong version number
|
||||
with pytest.raises(ValueError, match="not the latest"):
|
||||
await mlflow_adapter.update_prompt(
|
||||
prompt_id=created.prompt_id,
|
||||
prompt="Updated {{ var }}",
|
||||
version=999, # Wrong version
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
async def test_delete_not_supported(self, mlflow_adapter):
|
||||
"""Test that deletion raises NotImplementedError."""
|
||||
# Create prompt
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="Test {{ var }}",
|
||||
variables=["var"],
|
||||
)
|
||||
|
||||
# Try to delete (should fail with NotImplementedError)
|
||||
with pytest.raises(NotImplementedError, match="does not support deletion"):
|
||||
await mlflow_adapter.delete_prompt(created.prompt_id)
|
||||
|
||||
# Verify prompt still exists
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
assert retrieved.prompt_id == created.prompt_id
|
||||
|
||||
async def test_complex_template_with_multiple_variables(self, mlflow_adapter):
|
||||
"""Test prompt with complex template and multiple variables."""
|
||||
template = """You are a {{ role }} assistant specialized in {{ domain }}.
|
||||
|
||||
Task: {{ task }}
|
||||
|
||||
Context:
|
||||
{{ context }}
|
||||
|
||||
Instructions:
|
||||
1. {{ instruction1 }}
|
||||
2. {{ instruction2 }}
|
||||
3. {{ instruction3 }}
|
||||
|
||||
Output format: {{ output_format }}
|
||||
"""
|
||||
|
||||
# Create with auto-extraction
|
||||
created = await mlflow_adapter.create_prompt(prompt=template)
|
||||
|
||||
# Should extract all variables
|
||||
expected_vars = {
|
||||
"role",
|
||||
"domain",
|
||||
"task",
|
||||
"context",
|
||||
"instruction1",
|
||||
"instruction2",
|
||||
"instruction3",
|
||||
"output_format",
|
||||
}
|
||||
assert set(created.variables) == expected_vars
|
||||
|
||||
# Retrieve and verify template preserved
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
assert retrieved.prompt == template
|
||||
|
||||
async def test_empty_template(self, mlflow_adapter):
|
||||
"""Test handling of empty template."""
|
||||
# Create prompt with empty template
|
||||
created = await mlflow_adapter.create_prompt(
|
||||
prompt="",
|
||||
variables=[],
|
||||
)
|
||||
|
||||
assert created.prompt == ""
|
||||
assert created.variables == []
|
||||
|
||||
# Retrieve and verify
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
assert retrieved.prompt == ""
|
||||
|
||||
async def test_template_with_no_variables(self, mlflow_adapter):
|
||||
"""Test template without any variables."""
|
||||
template = "This is a static prompt with no variables."
|
||||
|
||||
created = await mlflow_adapter.create_prompt(prompt=template)
|
||||
|
||||
assert created.prompt == template
|
||||
assert created.variables == []
|
||||
|
||||
# Retrieve and verify
|
||||
retrieved = await mlflow_adapter.get_prompt(created.prompt_id)
|
||||
assert retrieved.prompt == template
|
||||
assert retrieved.variables == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue