mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
Merge 5c4da04f29 into 4237eb4aaa
This commit is contained in commit d5836c3b5a
24 changed files with 3594 additions and 225 deletions

tests/unit/providers/remote/prompts/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for remote prompts providers."""

tests/unit/providers/remote/prompts/mlflow/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for MLflow prompts provider."""

tests/unit/providers/remote/prompts/mlflow/test_config.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for MLflow prompts provider configuration."""

import pytest
from pydantic import SecretStr, ValidationError

from llama_stack.providers.remote.prompts.mlflow.config import (
    MLflowPromptsConfig,
    MLflowProviderDataValidator,
)


class TestMLflowPromptsConfig:
    """Tests for MLflowPromptsConfig model."""

    def test_default_config(self):
        """Test default configuration values."""
        config = MLflowPromptsConfig()

        assert config.mlflow_tracking_uri == "http://localhost:5000"
        assert config.mlflow_registry_uri is None
        assert config.experiment_name == "llama-stack-prompts"
        assert config.auth_credential is None
        assert config.timeout_seconds == 30

    def test_custom_config(self):
        """Test custom configuration values."""
        config = MLflowPromptsConfig(
            mlflow_tracking_uri="http://mlflow.example.com:8080",
            mlflow_registry_uri="http://registry.example.com:8080",
            experiment_name="my-prompts",
            auth_credential=SecretStr("my-token"),
            timeout_seconds=60,
        )

        assert config.mlflow_tracking_uri == "http://mlflow.example.com:8080"
        assert config.mlflow_registry_uri == "http://registry.example.com:8080"
        assert config.experiment_name == "my-prompts"
        assert config.auth_credential.get_secret_value() == "my-token"
        assert config.timeout_seconds == 60

    def test_databricks_uri(self):
        """Test Databricks URI configuration."""
        config = MLflowPromptsConfig(
            mlflow_tracking_uri="databricks",
            mlflow_registry_uri="databricks://profile",
        )

        assert config.mlflow_tracking_uri == "databricks"
        assert config.mlflow_registry_uri == "databricks://profile"

    def test_tracking_uri_validation(self):
        """Test tracking URI validation."""
        # Empty string rejected
        with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
            MLflowPromptsConfig(mlflow_tracking_uri="")

        # Whitespace-only rejected
        with pytest.raises(ValidationError, match="mlflow_tracking_uri cannot be empty"):
            MLflowPromptsConfig(mlflow_tracking_uri=" ")

        # Whitespace is stripped
        config = MLflowPromptsConfig(mlflow_tracking_uri=" http://localhost:5000 ")
        assert config.mlflow_tracking_uri == "http://localhost:5000"

    def test_experiment_name_validation(self):
        """Test experiment name validation."""
        # Empty string rejected
        with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
            MLflowPromptsConfig(experiment_name="")

        # Whitespace-only rejected
        with pytest.raises(ValidationError, match="experiment_name cannot be empty"):
            MLflowPromptsConfig(experiment_name=" ")

        # Whitespace is stripped
        config = MLflowPromptsConfig(experiment_name=" my-experiment ")
        assert config.experiment_name == "my-experiment"

    def test_timeout_validation(self):
        """Test timeout range validation."""
        # Too low rejected
        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=0)

        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=-1)

        # Too high rejected
        with pytest.raises(ValidationError):
            MLflowPromptsConfig(timeout_seconds=301)

        # Boundary values accepted
        config_min = MLflowPromptsConfig(timeout_seconds=1)
        assert config_min.timeout_seconds == 1

        config_max = MLflowPromptsConfig(timeout_seconds=300)
        assert config_max.timeout_seconds == 300

    def test_sample_run_config(self):
        """Test sample_run_config generates valid configuration."""
        # Default environment variable
        sample = MLflowPromptsConfig.sample_run_config()
        assert sample["mlflow_tracking_uri"] == "http://localhost:5000"
        assert sample["experiment_name"] == "llama-stack-prompts"
        assert sample["auth_credential"] == "${env.MLFLOW_TRACKING_TOKEN:=}"

        # Custom values
        sample = MLflowPromptsConfig.sample_run_config(
            mlflow_api_token="test-token",
            mlflow_tracking_uri="http://custom:5000",
        )
        assert sample["mlflow_tracking_uri"] == "http://custom:5000"
        assert sample["auth_credential"] == "test-token"


class TestMLflowProviderDataValidator:
    """Tests for MLflowProviderDataValidator."""

    def test_provider_data_validator(self):
        """Test provider data validator with and without token."""
        # With token
        validator = MLflowProviderDataValidator(mlflow_api_token="test-token-123")
        assert validator.mlflow_api_token == "test-token-123"

        # Without token
        validator = MLflowProviderDataValidator()
        assert validator.mlflow_api_token is None

        # From dictionary
        data = {"mlflow_api_token": "secret-token"}
        validator = MLflowProviderDataValidator(**data)
        assert validator.mlflow_api_token == "secret-token"
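
The config module these tests import is part of the same changeset but not shown in this hunk. As a reading aid, here is a minimal sketch of what the assertions above imply about MLflowPromptsConfig and MLflowProviderDataValidator, assuming pydantic v2 field validators and Field bounds; the field names, defaults, bounds, and error messages come from the tests, everything else (validator structure, sample_run_config signature) is an assumption and may differ from the PR's actual config.py.

# Hypothetical sketch inferred from the tests above; not the PR's actual config.py.
from pydantic import BaseModel, Field, SecretStr, field_validator


class MLflowPromptsConfig(BaseModel):
    mlflow_tracking_uri: str = "http://localhost:5000"
    mlflow_registry_uri: str | None = None
    experiment_name: str = "llama-stack-prompts"
    auth_credential: SecretStr | None = None
    timeout_seconds: int = Field(default=30, ge=1, le=300)  # bounds match the boundary tests

    @field_validator("mlflow_tracking_uri", "experiment_name")
    @classmethod
    def _not_empty(cls, value: str, info):
        # Tests expect whitespace to be stripped and blank values rejected.
        value = value.strip()
        if not value:
            raise ValueError(f"{info.field_name} cannot be empty")
        return value

    @classmethod
    def sample_run_config(
        cls, mlflow_api_token: str = "${env.MLFLOW_TRACKING_TOKEN:=}", **kwargs
    ) -> dict:
        # Defaults mirror what test_sample_run_config asserts.
        return {
            "mlflow_tracking_uri": kwargs.get("mlflow_tracking_uri", "http://localhost:5000"),
            "experiment_name": kwargs.get("experiment_name", "llama-stack-prompts"),
            "auth_credential": mlflow_api_token,
        }


class MLflowProviderDataValidator(BaseModel):
    mlflow_api_token: str | None = None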

tests/unit/providers/remote/prompts/mlflow/test_mapping.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for MLflow prompts ID mapping utilities."""

import pytest

from llama_stack.providers.remote.prompts.mlflow.mapping import PromptIDMapper


class TestPromptIDMapper:
    """Tests for PromptIDMapper class."""

    @pytest.fixture
    def mapper(self):
        """Create ID mapper instance."""
        return PromptIDMapper()

    def test_to_mlflow_name_valid(self, mapper):
        """Test converting valid prompt_id to MLflow name."""
        prompt_id = "pmpt_" + "a" * 48
        mlflow_name = mapper.to_mlflow_name(prompt_id)

        assert mlflow_name == "llama_prompt_" + "a" * 48
        assert mlflow_name.startswith(mapper.MLFLOW_NAME_PREFIX)

    def test_to_mlflow_name_invalid(self, mapper):
        """Test conversion fails with invalid inputs."""
        # Invalid prefix
        with pytest.raises(ValueError, match="Invalid prompt_id format"):
            mapper.to_mlflow_name("invalid_" + "a" * 48)

        # Wrong length
        with pytest.raises(ValueError, match="Invalid prompt_id format"):
            mapper.to_mlflow_name("pmpt_" + "a" * 47)

        # Invalid hex characters
        with pytest.raises(ValueError, match="Invalid prompt_id format"):
            mapper.to_mlflow_name("pmpt_" + "g" * 48)

    def test_to_llama_id_valid(self, mapper):
        """Test converting valid MLflow name to prompt_id."""
        mlflow_name = "llama_prompt_" + "b" * 48
        prompt_id = mapper.to_llama_id(mlflow_name)

        assert prompt_id == "pmpt_" + "b" * 48
        assert prompt_id.startswith("pmpt_")

    def test_to_llama_id_invalid(self, mapper):
        """Test conversion fails with invalid inputs."""
        # Invalid prefix
        with pytest.raises(ValueError, match="does not start with expected prefix"):
            mapper.to_llama_id("wrong_prefix_" + "a" * 48)

        # Wrong length
        with pytest.raises(ValueError, match="Invalid hex part length"):
            mapper.to_llama_id("llama_prompt_" + "a" * 47)

        # Invalid hex characters
        with pytest.raises(ValueError, match="Invalid character"):
            mapper.to_llama_id("llama_prompt_" + "G" * 48)

    def test_bidirectional_conversion(self, mapper):
        """Test bidirectional conversion preserves IDs."""
        original_id = "pmpt_0123456789abcdef" + "a" * 32

        # Convert to MLflow name and back
        mlflow_name = mapper.to_mlflow_name(original_id)
        recovered_id = mapper.to_llama_id(mlflow_name)

        assert recovered_id == original_id

    def test_get_metadata_tags_with_variables(self, mapper):
        """Test metadata tags generation with variables."""
        prompt_id = "pmpt_" + "c" * 48
        variables = ["var1", "var2", "var3"]

        tags = mapper.get_metadata_tags(prompt_id, variables)

        assert tags["llama_stack_id"] == prompt_id
        assert tags["llama_stack_managed"] == "true"
        assert tags["variables"] == "var1,var2,var3"

    def test_get_metadata_tags_without_variables(self, mapper):
        """Test metadata tags generation without variables."""
        prompt_id = "pmpt_" + "d" * 48

        tags = mapper.get_metadata_tags(prompt_id)

        assert tags["llama_stack_id"] == prompt_id
        assert tags["llama_stack_managed"] == "true"
        assert "variables" not in tags
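
The PromptIDMapper implementation is likewise outside this hunk. Below is a minimal sketch of the behavior these tests pin down: a reversible rename between "pmpt_" + 48 lowercase hex characters and "llama_prompt_" + the same hex, plus a small metadata tag dictionary. The class and method names, the MLFLOW_NAME_PREFIX constant, and the error message fragments are taken from the tests; the internal structure (regex check, constant names such as LLAMA_ID_PREFIX) is assumed and may not match the PR's mapping.py.

# Hypothetical sketch inferred from the tests above; not the PR's actual mapping.py.
import re


class PromptIDMapper:
    LLAMA_ID_PREFIX = "pmpt_"  # assumed constant name; prefix itself is asserted by the tests
    MLFLOW_NAME_PREFIX = "llama_prompt_"  # referenced directly by the tests
    _HEX_RE = re.compile(r"^[0-9a-f]{48}$")

    def to_mlflow_name(self, prompt_id: str) -> str:
        # pmpt_<48 lowercase hex chars> -> llama_prompt_<same hex>
        hex_part = prompt_id[len(self.LLAMA_ID_PREFIX):]
        if not prompt_id.startswith(self.LLAMA_ID_PREFIX) or not self._HEX_RE.match(hex_part):
            raise ValueError(f"Invalid prompt_id format: {prompt_id}")
        return self.MLFLOW_NAME_PREFIX + hex_part

    def to_llama_id(self, mlflow_name: str) -> str:
        # llama_prompt_<48 lowercase hex chars> -> pmpt_<same hex>
        if not mlflow_name.startswith(self.MLFLOW_NAME_PREFIX):
            raise ValueError(f"MLflow name does not start with expected prefix: {mlflow_name}")
        hex_part = mlflow_name[len(self.MLFLOW_NAME_PREFIX):]
        if len(hex_part) != 48:
            raise ValueError(f"Invalid hex part length: {len(hex_part)}")
        if not self._HEX_RE.match(hex_part):
            raise ValueError(f"Invalid character in hex part: {hex_part}")
        return self.LLAMA_ID_PREFIX + hex_part

    def get_metadata_tags(self, prompt_id: str, variables: list[str] | None = None) -> dict[str, str]:
        # Tags the tests assert on; "variables" is only present when variables are given.
        tags = {"llama_stack_id": prompt_id, "llama_stack_managed": "true"}
        if variables:
            tags["variables"] = ",".join(variables)
        return tags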