mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: Add MLflow Prompt Registry provider (squashed commit)
Add a new remote provider that integrates MLflow's Prompt Registry with Llama Stack's prompts API, enabling centralized prompt management and versioning using MLflow as the backend. Features: - Full implementation of Llama Stack Prompts protocol - Support for prompt versioning and default version management - Automatic variable extraction from Jinja2-style templates - MLflow tag-based metadata for efficient prompt filtering - Flexible authentication (config, environment variables, per-request) - Bidirectional ID mapping (pmpt_<hex> ↔ llama_prompt_<hex>) - Comprehensive error handling and validation Implementation: - Remote provider: src/llama_stack/providers/remote/prompts/mlflow/ - Inline reference provider: src/llama_stack/providers/inline/prompts/reference/ - MLflow 3.4+ required for Prompt Registry API support - Deterministic ID mapping ensures consistency across conversions Testing: - 15 comprehensive unit tests (config validation, ID mapping) - 18 end-to-end integration tests (full CRUD workflows) - GitHub Actions workflow for automated CI testing with MLflow server - Integration test fixtures with automatic server setup Documentation: - Complete provider configuration reference - Setup and usage examples with code samples - Authentication options and security best practices Signed-off-by: William Caban <william.caban@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
aac494c5ba
commit
0e0d311dea
24 changed files with 3594 additions and 225 deletions
95
tests/unit/providers/remote/prompts/mlflow/test_mapping.py
Normal file
95
tests/unit/providers/remote/prompts/mlflow/test_mapping.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Unit tests for MLflow prompts ID mapping utilities."""
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper
|
||||
|
||||
|
||||
class TestPromptIDMapper:
|
||||
"""Tests for PromptIDMapper class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mapper(self):
|
||||
"""Create ID mapper instance."""
|
||||
return PromptIDMapper()
|
||||
|
||||
def test_to_mlflow_name_valid(self, mapper):
|
||||
"""Test converting valid prompt_id to MLflow name."""
|
||||
prompt_id = "pmpt_" + "a" * 48
|
||||
mlflow_name = mapper.to_mlflow_name(prompt_id)
|
||||
|
||||
assert mlflow_name == "llama_prompt_" + "a" * 48
|
||||
assert mlflow_name.startswith(mapper.MLFLOW_NAME_PREFIX)
|
||||
|
||||
def test_to_mlflow_name_invalid(self, mapper):
|
||||
"""Test conversion fails with invalid inputs."""
|
||||
# Invalid prefix
|
||||
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||
mapper.to_mlflow_name("invalid_" + "a" * 48)
|
||||
|
||||
# Wrong length
|
||||
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||
mapper.to_mlflow_name("pmpt_" + "a" * 47)
|
||||
|
||||
# Invalid hex characters
|
||||
with pytest.raises(ValueError, match="Invalid prompt_id format"):
|
||||
mapper.to_mlflow_name("pmpt_" + "g" * 48)
|
||||
|
||||
def test_to_llama_id_valid(self, mapper):
|
||||
"""Test converting valid MLflow name to prompt_id."""
|
||||
mlflow_name = "llama_prompt_" + "b" * 48
|
||||
prompt_id = mapper.to_llama_id(mlflow_name)
|
||||
|
||||
assert prompt_id == "pmpt_" + "b" * 48
|
||||
assert prompt_id.startswith("pmpt_")
|
||||
|
||||
def test_to_llama_id_invalid(self, mapper):
|
||||
"""Test conversion fails with invalid inputs."""
|
||||
# Invalid prefix
|
||||
with pytest.raises(ValueError, match="does not start with expected prefix"):
|
||||
mapper.to_llama_id("wrong_prefix_" + "a" * 48)
|
||||
|
||||
# Wrong length
|
||||
with pytest.raises(ValueError, match="Invalid hex part length"):
|
||||
mapper.to_llama_id("llama_prompt_" + "a" * 47)
|
||||
|
||||
# Invalid hex characters
|
||||
with pytest.raises(ValueError, match="Invalid character"):
|
||||
mapper.to_llama_id("llama_prompt_" + "G" * 48)
|
||||
|
||||
def test_bidirectional_conversion(self, mapper):
|
||||
"""Test bidirectional conversion preserves IDs."""
|
||||
original_id = "pmpt_0123456789abcdef" + "a" * 32
|
||||
|
||||
# Convert to MLflow name and back
|
||||
mlflow_name = mapper.to_mlflow_name(original_id)
|
||||
recovered_id = mapper.to_llama_id(mlflow_name)
|
||||
|
||||
assert recovered_id == original_id
|
||||
|
||||
def test_get_metadata_tags_with_variables(self, mapper):
|
||||
"""Test metadata tags generation with variables."""
|
||||
prompt_id = "pmpt_" + "c" * 48
|
||||
variables = ["var1", "var2", "var3"]
|
||||
|
||||
tags = mapper.get_metadata_tags(prompt_id, variables)
|
||||
|
||||
assert tags["llama_stack_id"] == prompt_id
|
||||
assert tags["llama_stack_managed"] == "true"
|
||||
assert tags["variables"] == "var1,var2,var3"
|
||||
|
||||
def test_get_metadata_tags_without_variables(self, mapper):
|
||||
"""Test metadata tags generation without variables."""
|
||||
prompt_id = "pmpt_" + "d" * 48
|
||||
|
||||
tags = mapper.get_metadata_tags(prompt_id)
|
||||
|
||||
assert tags["llama_stack_id"] == prompt_id
|
||||
assert tags["llama_stack_managed"] == "true"
|
||||
assert "variables" not in tags
|
||||
Loading…
Add table
Add a link
Reference in a new issue