Mirror of https://github.com/meta-llama/llama-stack.git

commit f3923e3f0b (parent 6b094b72d3)

    Redo the { models, shields, memory_banks } typeset

15 changed files with 588 additions and 454 deletions
@@ -129,7 +129,10 @@ class StackConfigure(Subcommand):
         import yaml
         from termcolor import cprint

-        from llama_stack.distribution.configure import configure_api_providers
+        from llama_stack.distribution.configure import (
+            configure_api_providers,
+            parse_and_maybe_upgrade_config,
+        )
         from llama_stack.distribution.utils.serialize import EnumEncoder

         builds_dir = BUILDS_BASE_DIR / build_config.image_type
@@ -145,7 +148,8 @@ class StackConfigure(Subcommand):
                 "yellow",
                 attrs=["bold"],
             )
-            config = StackRunConfig(**yaml.safe_load(run_config_file.read_text()))
+            config_dict = yaml.safe_load(config_file.read_text())
+            config = parse_and_maybe_upgrade_config(config_dict)
         else:
             config = StackRunConfig(
                 built_at=datetime.now(),
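Taken together, the two hunks above route an existing run config through parse_and_maybe_upgrade_config instead of feeding the raw YAML straight into StackRunConfig. A minimal sketch of that load path outside the CLI, assuming a run.yaml in the current directory (the filename is illustrative; only the import and the call itself appear in the diff):

    # Sketch only: exercises the same upgrade entrypoint the CLI now uses.
    import yaml

    from llama_stack.distribution.configure import parse_and_maybe_upgrade_config

    with open("run.yaml") as f:          # illustrative path, not taken from the diff
        config_dict = yaml.safe_load(f)  # plain dict, old or new layout

    # Old-style dicts (routing_table / api_providers) are upgraded; new-style ones
    # ({ models, shields, memory_banks, providers }) are parsed as-is.
    config = parse_and_maybe_upgrade_config(config_dict)
    print(config.version)
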
@@ -1,105 +1,142 @@
-from argparse import Namespace
-from unittest.mock import MagicMock, patch
-
 import pytest
-from llama_stack.distribution.datatypes import BuildConfig
-from llama_stack.cli.stack.build import StackBuild
-
-
-# temporary while we make the tests work
-pytest.skip(allow_module_level=True)
+import yaml
+from datetime import datetime
+from llama_stack.distribution.configure import (
+    parse_and_maybe_upgrade_config,
+    LLAMA_STACK_RUN_CONFIG_VERSION,
+)


 @pytest.fixture
-def stack_build():
-    parser = MagicMock()
-    subparsers = MagicMock()
-    return StackBuild(subparsers)
-
-
-def test_stack_build_initialization(stack_build):
-    assert stack_build.parser is not None
-    assert stack_build.parser.set_defaults.called_once_with(
-        func=stack_build._run_stack_build_command
+def up_to_date_config():
+    return yaml.safe_load(
+        """
+        version: {version}
+        image_name: foo
+        apis_to_serve: []
+        built_at: {built_at}
+        models:
+          - identifier: model1
+            provider_id: provider1
+            llama_model: Llama3.1-8B-Instruct
+        shields:
+          - identifier: shield1
+            type: llama_guard
+            provider_id: provider1
+        memory_banks:
+          - identifier: memory1
+            type: vector
+            provider_id: provider1
+            embedding_model: all-MiniLM-L6-v2
+            chunk_size_in_tokens: 512
+        providers:
+          inference:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config: {{}}
+          safety:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config:
+                llama_guard_shield:
+                  model: Llama-Guard-3-1B
+                  excluded_categories: []
+                  disable_input_check: false
+                  disable_output_check: false
+                enable_prompt_guard: false
+          memory:
+            - provider_id: provider1
+              provider_type: meta-reference
+              config: {{}}
+        """.format(
+            version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
+        )
     )


-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_with_config(
-    mock_build_image, mock_build_config, stack_build
-):
-    args = Namespace(
-        config="test_config.yaml",
-        template=None,
-        list_templates=False,
-        name=None,
-        image_type="conda",
+@pytest.fixture
+def old_config():
+    return yaml.safe_load(
+        """
+        image_name: foo
+        built_at: {built_at}
+        apis_to_serve: []
+        routing_table:
+          inference:
+            - provider_type: remote::ollama
+              config:
+                host: localhost
+                port: 11434
+              routing_key: Llama3.2-1B-Instruct
+            - provider_type: meta-reference
+              config:
+                model: Llama3.1-8B-Instruct
+              routing_key: Llama3.1-8B-Instruct
+          safety:
+            - routing_key: ["shield1", "shield2"]
+              provider_type: meta-reference
+              config:
+                llama_guard_shield:
+                  model: Llama-Guard-3-1B
+                  excluded_categories: []
+                  disable_input_check: false
+                  disable_output_check: false
+                enable_prompt_guard: false
+          memory:
+            - routing_key: vector
+              provider_type: meta-reference
+              config: {{}}
+        api_providers:
+          telemetry:
+            provider_type: noop
+            config: {{}}
+        """.format(built_at=datetime.now().isoformat())
     )

-    with patch("builtins.open", MagicMock()):
-        with patch("yaml.safe_load") as mock_yaml_load:
-            mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
-            mock_build_config.return_value = MagicMock()
-
-            stack_build._run_stack_build_command(args)
-
-            mock_build_config.assert_called_once()
-            mock_build_image.assert_called_once()
+
+@pytest.fixture
+def invalid_config():
+    return yaml.safe_load("""
+        routing_table: {}
+        api_providers: {}
+    """)


-@patch("llama_stack.cli.table.print_table")
-def test_run_stack_build_command_list_templates(mock_print_table, stack_build):
-    args = Namespace(list_templates=True)
-
-    stack_build._run_stack_build_command(args)
-
-    mock_print_table.assert_called_once()
+def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
+    result = parse_and_maybe_upgrade_config(up_to_date_config)
+    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
+    assert len(result.models) == 1
+    assert len(result.shields) == 1
+    assert len(result.memory_banks) == 1
+    assert "inference" in result.providers


-@patch("prompt_toolkit.prompt")
-@patch("llama_stack.distribution.datatypes.BuildConfig")
-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_interactive(
-    mock_build_image, mock_build_config, mock_prompt, stack_build
-):
-    args = Namespace(
-        config=None, template=None, list_templates=False, name=None, image_type=None
+def test_parse_and_maybe_upgrade_config_old_format(old_config):
+    result = parse_and_maybe_upgrade_config(old_config)
+    assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
+    assert len(result.models) == 2
+    assert len(result.shields) == 2
+    assert len(result.memory_banks) == 1
+    assert all(
+        api in result.providers
+        for api in ["inference", "safety", "memory", "telemetry"]
     )
+    safety_provider = result.providers["safety"][0]
+    assert safety_provider.provider_type == "meta-reference"
+    assert "llama_guard_shield" in safety_provider.config

-    mock_prompt.side_effect = [
-        "test_name",
-        "conda",
-        "meta-reference",
-        "test description",
-    ]
-    mock_build_config.return_value = MagicMock()
+    inference_providers = result.providers["inference"]
+    assert len(inference_providers) == 2
+    assert set(x.provider_id for x in inference_providers) == {
+        "remote::ollama-00",
+        "meta-reference-01",
+    }

-    stack_build._run_stack_build_command(args)
-
-    assert mock_prompt.call_count == 4
-    mock_build_config.assert_called_once()
-    mock_build_image.assert_called_once()
+    ollama = inference_providers[0]
+    assert ollama.provider_type == "remote::ollama"
+    assert ollama.config["port"] == 11434


-@patch("llama_stack.distribution.datatypes.BuildConfig")
-@patch("llama_stack.distribution.build.build_image")
-def test_run_stack_build_command_with_template(
-    mock_build_image, mock_build_config, stack_build
-):
-    args = Namespace(
-        config=None,
-        template="test_template",
-        list_templates=False,
-        name="test_name",
-        image_type="docker",
-    )
-
-    with patch("builtins.open", MagicMock()):
-        with patch("yaml.safe_load") as mock_yaml_load:
-            mock_yaml_load.return_value = {"name": "test_build", "image_type": "conda"}
-            mock_build_config.return_value = MagicMock()
-
-            stack_build._run_stack_build_command(args)
-
-            mock_build_config.assert_called_once()
-            mock_build_image.assert_called_once()
+def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
+    with pytest.raises(ValueError):
+        parse_and_maybe_upgrade_config(invalid_config)
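The assertions in the new tests also pin down how the upgrade names things: each old routing_table entry becomes a provider whose id is its provider_type plus a positional suffix, and the routing keys resurface as entries in the new top-level { models, shields, memory_banks } lists. A rough summary of that mapping for the first inference entry of old_config; anything marked as an assumption below is inferred, since the tests only check counts, provider ids, and the Ollama port:

    # Inferred from the test assertions above; this is not the upgrade implementation.
    old_entry = {
        "provider_type": "remote::ollama",
        "config": {"host": "localhost", "port": 11434},
        "routing_key": "Llama3.2-1B-Instruct",
    }

    # Provider expected after the upgrade (id, type, and config["port"] are asserted).
    upgraded_provider = {
        "provider_id": "remote::ollama-00",  # provider_type plus zero-padded position
        "provider_type": "remote::ollama",
        "config": {"host": "localhost", "port": 11434},
    }

    # Registered model expected after the upgrade (assumption: the routing_key
    # becomes the model identifier; the tests only assert that two models exist).
    upgraded_model = {
        "identifier": "Llama3.2-1B-Instruct",
        "provider_id": "remote::ollama-00",
    }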