Add unit tests for CLI to increase code coverage + fix bug

When running `llama stack run`, using `config` instead of `--config` caused an error. This is now fixed and covered by unit tests.

Unit test coverage has been added for the basic CLI argument parsing.
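For context, a minimal sketch of the difference using plain argparse (standalone, not the project's actual StackRun subcommand wiring; "run.yaml" is a made-up path):

import argparse

# Before the fix: the argument was declared positionally, so the flag form
# was rejected with "unrecognized arguments: --config".
old = argparse.ArgumentParser(prog="llama stack run")
old.add_argument("config", type=str)
# old.parse_args(["--config", "run.yaml"])  # exits with the error above

# After the fix: declaring it as "--config" accepts the flag form exercised
# by the new unit tests.
new = argparse.ArgumentParser(prog="llama stack run")
new.add_argument("--config", type=str)
args = new.parse_args(["--config", "run.yaml"])
assert args.config == "run.yaml"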

Signed-off-by: Courtney Pacheco <6019922+courtneypacheco@users.noreply.github.com>
Courtney Pacheco 2025-03-10 21:15:02 -04:00
parent ff853ccc38
commit 6beb80f2ac
3 changed files with 706 additions and 2 deletions


@@ -30,7 +30,7 @@ class StackRun(Subcommand):
     def _add_arguments(self):
         self.parser.add_argument(
-            "config",
+            "--config",
             type=str,
             help="Path to config file to use for the run",
         )


@@ -0,0 +1,700 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import pytest
from llama_stack.cli.llama import (
LlamaCLIParser,
main,
)
# Note: Running the unit tests through 'uv' confuses pytest. pytest will use uv's "sys.argv" values, not the ones we intend to test.
class TestMain:
@patch("sys.argv", ("llama", "stack", "list-providers"))
def test_run(self):
main()
# Note: Running the unit tests through 'uv' confuses pytest. pytest will use uv's "sys.argv" values, not the ones we intend to test.
class TestLlamaCLIParser:
@patch("sys.argv", ("llama", "fake-choice"))
def test_invalid_choice(self):
with pytest.raises(SystemExit):
parser = LlamaCLIParser()
parser.parse_args()
@patch("sys.argv", ("llama", "stack", "list-providers"))
def test_run(self):
parser = LlamaCLIParser()
args = parser.parse_args()
parser.run(args)
# Note: Running the unit tests through 'uv' confuses pytest. pytest will use uv's "sys.argv" values, not the ones we intend to test.
class TestDownloadSubcommand:
@pytest.mark.parametrize(
"test_args, expected",
[
( # Download model from meta, using default number of parallel downloads
(
"llama",
"download",
"--source",
"meta",
"--model-id",
"Llama3.2-1B",
"--meta-url",
"url-obtained-from-llama.meta.com",
),
{
"source": "meta",
"model_id": "Llama3.2-1B",
"meta_url": "url-obtained-from-llama.meta.com",
"max_parallel": 3,
},
),
( # Download model from meta, setting --max-parallel=10
(
"llama",
"download",
"--source",
"meta",
"--model-id",
"Llama3.2-1B",
"--meta-url",
"url-obtained-from-llama.meta.com",
"--max-parallel",
"10",
),
{
"source": "meta",
"model_id": "Llama3.2-1B",
"meta_url": "url-obtained-from-llama.meta.com",
"max_parallel": 10,
},
),
( # Download two models from meta
(
"llama",
"download",
"--source",
"meta",
"--model-id",
"Llama3.2-1B,Llama3.2-3B",
"--meta-url",
"url-obtained-from-llama.meta.com",
),
{
"source": "meta",
"model_id": "Llama3.2-1B,Llama3.2-3B",
"meta_url": "url-obtained-from-llama.meta.com",
"max_parallel": 3,
},
),
],
)
def test_download_from_meta_url(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.source == expected["source"]
assert args.model_id == expected["model_id"]
assert args.meta_url == expected["meta_url"]
assert args.max_parallel == expected["max_parallel"]
@pytest.mark.parametrize(
"test_args, expected",
[
( # Download model from Hugging Face, using default number of parallel downloads
(
"llama",
"download",
"--source",
"huggingface",
"--model-id",
"Llama3.2-1B",
"--hf-token",
"fake-hf-token",
"--ignore-patterns",
"*.something",
),
{
"source": "huggingface",
"model_id": "Llama3.2-1B",
"hf_token": "fake-hf-token",
"ignore_patterns": "*.something",
"max_parallel": 3,
},
),
( # Download model from Hugging Face, setting --max-parallel=10
(
"llama",
"download",
"--source",
"huggingface",
"--model-id",
"Llama3.2-1B",
"--hf-token",
"fake-hf-token",
"--ignore-patterns",
"*.something",
"--max-parallel",
"10",
),
{
"source": "huggingface",
"model_id": "Llama3.2-1B",
"hf_token": "fake-hf-token",
"ignore_patterns": "*.something",
"max_parallel": 10,
},
),
],
)
def test_download_from_hugging_face(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.source == expected["source"]
assert args.model_id == expected["model_id"]
assert args.hf_token == expected["hf_token"]
assert args.max_parallel == expected["max_parallel"]
assert args.ignore_patterns == expected["ignore_patterns"]
class TestVerifyDownloadSubcommand:
def test_verify_downloaded_model(self):
test_args = ("llama", "verify-download", "--model-id", "Llama3.2-1B")
expected_model_id = "Llama3.2-1B"
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.model_id == expected_model_id
# Note: Running the unit tests through 'uv' confuses pytest. pytest will use uv's "sys.argv" values, not the ones we intend to test.
class TestVerifyStackSubcommand:
@pytest.mark.parametrize(
"test_args, expected",
[
( # Verify stack with --image_type=container
(
"llama",
"stack",
"build",
"--config",
"./some/path/to/config",
"--template",
"some-template",
"--image-type",
"container",
),
{
"config": "./some/path/to/config",
"template": "some-template",
"image_type": "container",
"image_name": None,
},
),
( # Verify stack with --image_type=conda
(
"llama",
"stack",
"build",
"--config",
"./some/path/to/config",
"--template",
"some-template",
"--image-type",
"conda",
"--image-name",
"name-of-conda-env",
),
{
"config": "./some/path/to/config",
"template": "some-template",
"image_type": "conda",
"image_name": "name-of-conda-env",
},
),
( # Verify stack with --image_type=venv
(
"llama",
"stack",
"build",
"--config",
"./some/path/to/config",
"--template",
"some-template",
"--image-type",
"venv",
"--image-name",
"name-of-venv",
),
{
"config": "./some/path/to/config",
"template": "some-template",
"image_type": "venv",
"image_name": "name-of-venv",
},
),
],
)
def test_stack_build_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.config == expected["config"]
assert args.template == expected["template"]
assert args.image_type == expected["image_type"]
assert args.image_name == expected["image_name"]
@patch("sys.argv", ("llama", "stack", "list-apis"))
def test_list_apis_subcommand(self):
parser = LlamaCLIParser()
parser.parse_args()
@patch("sys.argv", ("llama", "stack", "list-providers"))
def test_list_providers_subcommand(self):
parser = LlamaCLIParser()
parser.parse_args()
@pytest.mark.parametrize(
"test_args, expected",
[
( # conda: Using default values
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"conda",
"--image-name",
"some-image",
"--env",
"SOME_KEY=VALUE",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 8321,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": False,
"image_type": "conda",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
( # conda: Using custom inputs to override defaults (like: port, env, etc.)
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"conda",
"--port",
"9999",
"--env",
"SOME_KEY=VALUE",
"--disable-ipv6",
"--image-name",
"some-image",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 9999,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": True,
"image_type": "conda",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
( # container: Using default values
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"container",
"--image-name",
"some-image",
"--env",
"SOME_KEY=VALUE",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 8321,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": False,
"image_type": "container",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
( # container: Using custom inputs to override defaults (like: port, env, etc.)
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"container",
"--port",
"9999",
"--env",
"SOME_KEY=VALUE",
"--disable-ipv6",
"--image-name",
"some-image",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 9999,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": True,
"image_type": "container",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
( # venv: Using default values
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"venv",
"--image-name",
"some-image",
"--env",
"SOME_KEY=VALUE",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 8321,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": False,
"image_type": "venv",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
( # venv: Using custom inputs to override defaults (like: port, env, etc.)
(
"llama",
"stack",
"run",
"--config",
"./some/path/to/config",
"--image-type",
"venv",
"--port",
"9999",
"--env",
"SOME_KEY=VALUE",
"--disable-ipv6",
"--image-name",
"some-image",
"--tls-certfile",
"tlscert",
"--tls-keyfile",
"tlskeyfile",
),
{
"config": "./some/path/to/config",
"port": 9999,
"env": ["SOME_KEY=VALUE"],
"image_name": "some-image",
"disable_ipv6": True,
"image_type": "venv",
"tls_certfile": "tlscert",
"tls_keyfile": "tlskeyfile",
},
),
],
)
def test_run_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.config == expected["config"]
assert args.port == expected["port"]
assert args.env == expected["env"]
assert args.image_type == expected["image_type"]
assert args.image_name == expected["image_name"]
assert args.disable_ipv6 == expected["disable_ipv6"]
assert args.tls_certfile == expected["tls_certfile"]
assert args.tls_keyfile == expected["tls_keyfile"]
# Note: Running the unit tests through 'uv' confuses pytest. pytest will use uv's "sys.argv" values, not the ones we intend to test.
class TestModelSubcommand:
def test_model_download_subcommand(self):
test_args = ("llama", "model", "download")
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
parser.parse_args()
@pytest.mark.parametrize(
"test_args, expected",
[
( # Show all models
(
"llama",
"model",
"list",
"--show-all",
),
{
"show_all": True,
"downloaded": False,
"search": None,
},
),
( # List downloaded models
(
"llama",
"model",
"list",
"--downloaded",
),
{
"show_all": False,
"downloaded": True,
"search": None,
},
),
( # Search downloaded models
(
"llama",
"model",
"list",
"--search",
"some-search-str",
),
{
"show_all": False,
"downloaded": False,
"search": "some-search-str",
},
),
],
)
def test_model_list_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.show_all == expected["show_all"]
assert args.downloaded == expected["downloaded"]
assert args.search == expected["search"]
@pytest.mark.parametrize(
"test_args, expected",
[
( # List all available models
(
"llama",
"model",
"prompt-format",
"--list",
),
{
"list": True,
"model_name": "llama3_1",
},
),
( # Describe model prompt format for a model
(
"llama",
"model",
"prompt-format",
"--model-name",
"llamaX_Y",
),
{
"list": False,
"model_name": "llamaX_Y",
},
),
( # Describe model prompt format for a model
(
"llama",
"model",
"prompt-format",
"-m",
"llamaX_Y",
),
{
"list": False,
"model_name": "llamaX_Y",
},
),
],
)
def test_model_prompt_format_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.list == expected["list"]
assert args.model_name == expected["model_name"]
@pytest.mark.parametrize(
"test_args, expected",
[
( # Describe existing model
(
"llama",
"model",
"describe",
"--model-id",
"llamaX_Y",
),
{
"model_id": "llamaX_Y",
},
),
( # Call subcommand while omitting "--model-id". Should raise an exception because it's a required arg.
(
"llama",
"model",
"describe",
),
SystemExit,
),
],
)
def test_model_describe_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
if type(expected) is type and issubclass(expected, BaseException):
with pytest.raises(expected):
parser = LlamaCLIParser()
args = parser.parse_args()
else:
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.model_id == expected["model_id"]
@pytest.mark.parametrize(
"test_args, expected",
[
( # Describe existing model
(
"llama",
"model",
"verify-download",
"--model-id",
"llamaX_Y",
),
None,
),
( # Call subcommand while omitting "--model-id". Should raise an exception because it's a required arg.
(
"llama",
"model",
"verify-download",
),
SystemExit,
),
],
)
def test_model_verify_download_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
if type(expected) is type and issubclass(expected, BaseException):
with pytest.raises(expected):
parser = LlamaCLIParser()
parser.parse_args()
else:
parser = LlamaCLIParser()
parser.parse_args()
@pytest.mark.parametrize(
"test_args, expected",
[
( # Remove an existing model
(
"llama",
"model",
"remove",
"--model",
"llamaX_Y",
),
{
"model": "llamaX_Y",
"force": False,
},
),
( # Forcefully remove an existing model
(
"llama",
"model",
"remove",
"--model",
"llamaX_Y",
"--force",
),
{
"model": "llamaX_Y",
"force": True,
},
),
( # Call subcommand while omitting "--model". Should raise an exception because it's a required arg.
(
"llama",
"model",
"remove",
),
SystemExit,
),
],
)
def test_model_remove_subcommand(self, test_args, expected):
with patch("sys.argv", test_args):
if type(expected) is type and issubclass(expected, BaseException):
with pytest.raises(expected):
parser = LlamaCLIParser()
parser.parse_args()
else:
parser = LlamaCLIParser()
args = parser.parse_args()
assert args.model == expected["model"]
assert args.force == expected["force"]
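As a side note on the `sys.argv` patching used throughout the tests above: argparse's parse_args() falls back to sys.argv[1:] when called with no arguments, so the @patch("sys.argv", ...) decorators control exactly what the parser sees, regardless of how pytest itself was launched (e.g. via uv). A minimal standalone sketch of the pattern, using a hypothetical parser rather than LlamaCLIParser:

import argparse
from unittest.mock import patch


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for LlamaCLIParser, only to illustrate the pattern.
    parser = argparse.ArgumentParser(prog="llama")
    parser.add_argument("--config", type=str)
    return parser


@patch("sys.argv", ("llama", "--config", "./some/path/to/config"))
def test_config_flag_is_parsed():
    # parse_args() with no explicit argv reads sys.argv[1:], which is exactly
    # what the @patch decorator above replaces.
    args = build_parser().parse_args()
    assert args.config == "./some/path/to/config"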


@@ -230,5 +230,9 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
     # records from our chat completion call. A message gets logged
     # here any time we exceed the slow_callback_duration configured
     # above.
-    asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
+    asyncio_warnings = [
+        record.message
+        for record in caplog.records
+        if record.name == "asyncio" and "Executing <Task" not in record.message
+    ]
     assert not asyncio_warnings
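For background on the message being filtered out above, a standalone illustration (not code from this repo): when asyncio runs in debug mode, any callback or Task step that takes longer than loop.slow_callback_duration is reported by the "asyncio" logger with a message beginning "Executing <Task ...> took N seconds", which is exactly the text the list comprehension now excludes.

import asyncio
import logging
import time

logging.basicConfig(level=logging.WARNING)


async def main() -> None:
    loop = asyncio.get_running_loop()
    loop.slow_callback_duration = 0.1  # warn on anything slower than 100 ms
    time.sleep(0.2)  # deliberately block the loop inside the running Task


# debug=True enables asyncio's slow-callback reporting; this logs a warning
# from the "asyncio" logger along the lines of
# "Executing <Task finished ... coro=<main() ...>> took 0.200 seconds".
asyncio.run(main(), debug=True)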