chore: remove dependency on llama_models completely (#1344)

This commit is contained in:
Ashwin Bharambe 2025-03-01 12:48:08 -08:00 committed by GitHub
parent 7131d5ddeb
commit 8bbd52bb9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 131358 additions and 202 deletions

View file

@ -13,31 +13,38 @@
import importlib
from pathlib import Path
from typing import Optional
import fire
# from llama_stack.models.llama.datatypes import * # noqa: F403
from llama_models.llama3.reference_impl.generation import Llama
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig
from llama_stack.providers.inline.inference.meta_reference.generation import Llama
THIS_DIR = Path(__file__).parent.resolve()
def run_main(
ckpt_dir: str,
model_id: str,
checkpoint_dir: str,
module_name: str,
output_path: str,
model_parallel_size: Optional[int] = None,
):
module = importlib.import_module(module_name)
assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function"
tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model")
generator = Llama.build(
ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,
config = MetaReferenceInferenceConfig(
model=model_id,
max_seq_len=512,
max_batch_size=1,
model_parallel_size=model_parallel_size,
checkpoint_dir=checkpoint_dir,
)
llama_model = resolve_model(model_id)
if not llama_model:
raise ValueError(f"Model {model_id} not found")
generator = Llama.build(
config=config,
model_id=model_id,
llama_model=llama_model,
)
use_cases = module.usecases()