mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Add a new remote provider that integrates MLflow's Prompt Registry with Llama Stack's prompts API, enabling centralized prompt management and versioning using MLflow as the backend. Features: - Full implementation of Llama Stack Prompts protocol - Support for prompt versioning and default version management - Automatic variable extraction from Jinja2-style templates - MLflow tag-based metadata for efficient prompt filtering - Flexible authentication (config, environment variables, per-request) - Bidirectional ID mapping (pmpt_<hex> ↔ llama_prompt_<hex>) - Comprehensive error handling and validation Implementation: - Remote provider: src/llama_stack/providers/remote/prompts/mlflow/ - Inline reference provider: src/llama_stack/providers/inline/prompts/reference/ - MLflow 3.4+ required for Prompt Registry API support - Deterministic ID mapping ensures consistency across conversions Testing: - 15 comprehensive unit tests (config validation, ID mapping) - 18 end-to-end integration tests (full CRUD workflows) - GitHub Actions workflow for automated CI testing with MLflow server - Integration test fixtures with automatic server setup Documentation: - Complete provider configuration reference - Setup and usage examples with code samples - Authentication options and security best practices Signed-off-by: William Caban <william.caban@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com>
138 lines
5.2 KiB
Python
138 lines
5.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
"""Unit tests for MLflow prompts provider configuration."""
|
|
|
|
import pytest
|
|
from pydantic import SecretStr, ValidationError
|
|
|
|
from llama_stack.providers.remote.prompts.mlflow.config import (
|
|
MLflowPromptsConfig,
|
|
MLflowProviderDataValidator,
|
|
)
|
|
|
|
|
|
class TestMLflowPromptsConfig:
    """Unit tests covering the MLflowPromptsConfig pydantic model."""

    def test_default_config(self):
        """A bare config exposes the documented default values."""
        cfg = MLflowPromptsConfig()

        assert cfg.mlflow_tracking_uri == "http://localhost:5000"
        assert cfg.mlflow_registry_uri is None
        assert cfg.experiment_name == "llama-stack-prompts"
        assert cfg.auth_credential is None
        assert cfg.timeout_seconds == 30

    def test_custom_config(self):
        """Explicitly supplied values round-trip through the model unchanged."""
        cfg = MLflowPromptsConfig(
            mlflow_tracking_uri="http://mlflow.example.com:8080",
            mlflow_registry_uri="http://registry.example.com:8080",
            experiment_name="my-prompts",
            auth_credential=SecretStr("my-token"),
            timeout_seconds=60,
        )

        assert cfg.mlflow_tracking_uri == "http://mlflow.example.com:8080"
        assert cfg.mlflow_registry_uri == "http://registry.example.com:8080"
        assert cfg.experiment_name == "my-prompts"
        # SecretStr hides the value in repr; unwrap it explicitly to compare.
        assert cfg.auth_credential.get_secret_value() == "my-token"
        assert cfg.timeout_seconds == 60

    def test_databricks_uri(self):
        """Databricks-style URIs are accepted verbatim (no scheme rewriting)."""
        cfg = MLflowPromptsConfig(
            mlflow_tracking_uri="databricks",
            mlflow_registry_uri="databricks://profile",
        )

        assert cfg.mlflow_tracking_uri == "databricks"
        assert cfg.mlflow_registry_uri == "databricks://profile"

    def test_tracking_uri_validation(self):
        """Blank tracking URIs are rejected and surrounding whitespace is stripped."""
        # An empty string must fail validation.
        with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
            MLflowPromptsConfig(mlflow_tracking_uri="")

        # A whitespace-only string must fail validation as well.
        with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
            MLflowPromptsConfig(mlflow_tracking_uri=" ")

        # Leading/trailing whitespace around a valid URI is stripped.
        cfg = MLflowPromptsConfig(mlflow_tracking_uri=" http://localhost:5000 ")
        assert cfg.mlflow_tracking_uri == "http://localhost:5000"

    def test_experiment_name_validation(self):
        """Blank experiment names are rejected and whitespace is stripped."""
        # An empty string must fail validation.
        with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
            MLflowPromptsConfig(experiment_name="")

        # A whitespace-only string must fail validation as well.
        with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
            MLflowPromptsConfig(experiment_name=" ")

        # Leading/trailing whitespace around a valid name is stripped.
        cfg = MLflowPromptsConfig(experiment_name=" my-experiment ")
        assert cfg.experiment_name == "my-experiment"

    def test_timeout_validation(self):
        """timeout_seconds must lie in the inclusive range [1, 300]."""
        # Values at or below zero are rejected.
        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=0)
        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=-1)

        # Values above the ceiling are rejected.
        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=301)

        # Both inclusive boundaries are accepted.
        lower = MLflowPromptsConfig(timeout_seconds=1)
        assert lower.timeout_seconds == 1
        upper = MLflowPromptsConfig(timeout_seconds=300)
        assert upper.timeout_seconds == 300

    def test_sample_run_config(self):
        """sample_run_config emits a usable config dict with env-var defaults."""
        # With no arguments the token field falls back to an env-var template.
        rendered = MLflowPromptsConfig.sample_run_config()
        assert rendered["mlflow_tracking_uri"] == "http://localhost:5000"
        assert rendered["experiment_name"] == "llama-stack-prompts"
        assert rendered["auth_credential"] == "${env.MLFLOW_TRACKING_TOKEN:=}"

        # Keyword overrides replace both the URI and the credential.
        rendered = MLflowPromptsConfig.sample_run_config(
            mlflow_api_token="test-token",
            mlflow_tracking_uri="http://custom:5000",
        )
        assert rendered["mlflow_tracking_uri"] == "http://custom:5000"
        assert rendered["auth_credential"] == "test-token"
|
|
|
class TestMLflowProviderDataValidator:
    """Unit tests covering the MLflowProviderDataValidator model."""

    def test_provider_data_validator(self):
        """The validator accepts an optional token via kwarg or dict payload."""
        # Token supplied directly as a keyword argument.
        v = MLflowProviderDataValidator(mlflow_api_token="test-token-123")
        assert v.mlflow_api_token == "test-token-123"

        # Omitting the token leaves the field as None.
        v = MLflowProviderDataValidator()
        assert v.mlflow_api_token is None

        # Token supplied by unpacking a request-style dictionary.
        payload = {"mlflow_api_token": "secret-token"}
        v = MLflowProviderDataValidator(**payload)
        assert v.mlflow_api_token == "secret-token"