feat: introduce llama4 support (#1877)

As title says. Details in README, elsewhere.
This commit is contained in:
Ashwin Bharambe 2025-04-05 11:53:35 -07:00 committed by GitHub
parent 23a99a4b22
commit b8f1561956
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
61 changed files with 205222 additions and 6439 deletions

View file

@ -4,23 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from copy import deepcopy
from functools import partial
from typing import Any, Generator
from typing import Any, Callable, Generator
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
from .llama3.generation import Llama3
from .parallel_utils import ModelParallelProcessGroup
@ -39,11 +33,10 @@ class ModelRunner:
def init_model_cb(
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
builder_fn: Callable,
params: list[Any],
):
llama = Llama3.build(config, model_id, llama_model)
llama = builder_fn(*params)
return ModelRunner(llama)
@ -60,25 +53,15 @@ class LlamaModelParallelGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
model_parallel_size: int,
builder_fn: Callable,
builder_params: list[Any],
formatter: Llama3ChatFormat | Llama4ChatFormat,
):
self.config = config
self.model_id = model_id
self.llama_model = llama_model
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
# while the tool-use loop is going
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
checkpoint_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
self.model_parallel_size = model_parallel_size
self.builder_fn = builder_fn
self.builder_params = builder_params
self.formatter = formatter
def start(self):
self.__enter__()
@ -87,11 +70,9 @@ class LlamaModelParallelGenerator:
self.__exit__(None, None, None)
def __enter__(self):
model_parallel_size = self.llama_model.pth_file_count
self.group = ModelParallelProcessGroup(
model_parallel_size,
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
self.model_parallel_size,
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
)
self.group.start()
return self