forked from phoenix-oss/llama-stack-mirror
refactor: move generation.py to llama3
This commit is contained in:
parent
02066591b8
commit
816fdf289a
2 changed files with 3 additions and 3 deletions
|
@ -18,8 +18,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
CompletionRequestWithRawContent,
|
CompletionRequestWithRawContent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .common import model_checkpoint_dir
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generation import Llama, model_checkpoint_dir
|
from .llama3.generation import Llama3
|
||||||
from .parallel_utils import ModelParallelProcessGroup
|
from .parallel_utils import ModelParallelProcessGroup
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +43,7 @@ def init_model_cb(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
llama_model: Model,
|
llama_model: Model,
|
||||||
):
|
):
|
||||||
llama = Llama.build(config, model_id, llama_model)
|
llama = Llama3.build(config, model_id, llama_model)
|
||||||
return ModelRunner(llama)
|
return ModelRunner(llama)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc):
|
||||||
if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
|
if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
|
||||||
params.append(pytest.param(model, id=model))
|
params.append(pytest.param(model, id=model))
|
||||||
|
|
||||||
print(f"params: {params}")
|
|
||||||
if not params:
|
if not params:
|
||||||
model = metafunc.config.getoption("--inference-model")
|
model = metafunc.config.getoption("--inference-model")
|
||||||
params = [pytest.param(model, id=model)]
|
params = [pytest.param(model, id=model)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue