refactor: move generation.py to llama3

This commit is contained in:
Ashwin Bharambe 2025-03-03 13:38:06 -08:00
parent 02066591b8
commit 816fdf289a
2 changed files with 3 additions and 3 deletions

View file

@ -18,8 +18,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent, CompletionRequestWithRawContent,
) )
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir from .llama3.generation import Llama3
from .parallel_utils import ModelParallelProcessGroup from .parallel_utils import ModelParallelProcessGroup
@ -42,7 +43,7 @@ def init_model_cb(
model_id: str, model_id: str,
llama_model: Model, llama_model: Model,
): ):
llama = Llama.build(config, model_id, llama_model) llama = Llama3.build(config, model_id, llama_model)
return ModelRunner(llama) return ModelRunner(llama)

View file

@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc):
if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model): if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
params.append(pytest.param(model, id=model)) params.append(pytest.param(model, id=model))
print(f"params: {params}")
if not params: if not params:
model = metafunc.config.getoption("--inference-model") model = metafunc.config.getoption("--inference-model")
params = [pytest.param(model, id=model)] params = [pytest.param(model, id=model)]