From 816fdf289ac726ebebc0714adac66406c6540cb8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 3 Mar 2025 13:38:06 -0800 Subject: [PATCH] refactor: move generation.py to llama3 --- .../inline/inference/meta_reference/model_parallel.py | 5 +++-- llama_stack/providers/tests/inference/conftest.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 356cbfe7e..954da81b8 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -18,8 +18,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( CompletionRequestWithRawContent, ) +from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig -from .generation import Llama, model_checkpoint_dir +from .llama3.generation import Llama3 from .parallel_utils import ModelParallelProcessGroup @@ -42,7 +43,7 @@ def init_model_cb( model_id: str, llama_model: Model, ): - llama = Llama.build(config, model_id, llama_model) + llama = Llama3.build(config, model_id, llama_model) return ModelRunner(llama) diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 0075ff80d..fde787ab3 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc): if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model): params.append(pytest.param(model, id=model)) - print(f"params: {params}") if not params: model = metafunc.config.getoption("--inference-model") params = [pytest.param(model, id=model)]