forked from phoenix-oss/llama-stack-mirror
feat: introduce llama4 support (#1877)
As title says. Details in README, elsewhere.
parent 23a99a4b22
commit b8f1561956

61 changed files with 205222 additions and 6439 deletions
@@ -4,23 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator
+from typing import Any, Callable, Generator
 
-from llama_stack.models.llama.datatypes import Model
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
+from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
 
-from .common import model_checkpoint_dir
-from .config import MetaReferenceInferenceConfig
-from .llama3.generation import Llama3
 from .parallel_utils import ModelParallelProcessGroup
 
 
@@ -39,11 +33,10 @@ class ModelRunner:
 
 
 def init_model_cb(
-    config: MetaReferenceInferenceConfig,
-    model_id: str,
-    llama_model: Model,
+    builder_fn: Callable,
+    params: list[Any],
 ):
-    llama = Llama3.build(config, model_id, llama_model)
+    llama = builder_fn(*params)
     return ModelRunner(llama)
 
 
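With this change, init_model_cb no longer hard-codes Llama3.build: it applies whatever builder function it is handed and wraps the result in a ModelRunner. Below is a minimal, self-contained sketch of that contract; every name in it (FakeRunner, build_fake_llama, the checkpoint string) is an illustrative stand-in, not an actual llama-stack class or builder.

from typing import Any, Callable


class FakeRunner:
    # Stands in for ModelRunner: it just wraps whatever the builder returns.
    def __init__(self, llama: Any) -> None:
        self.llama = llama


def init_cb(builder_fn: Callable, params: list[Any]) -> FakeRunner:
    # Mirrors the refactored init_model_cb: apply the builder, wrap the result.
    return FakeRunner(builder_fn(*params))


def build_fake_llama(model_id: str, world_size: int) -> dict:
    # Hypothetical builder standing in for Llama3.build or a Llama4 equivalent.
    return {"model_id": model_id, "world_size": world_size}


runner = init_cb(build_fake_llama, ["some-llama4-checkpoint", 8])
print(runner.llama)  # {'model_id': 'some-llama4-checkpoint', 'world_size': 8}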
@@ -60,25 +53,15 @@ class LlamaModelParallelGenerator:
 
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig,
-        model_id: str,
-        llama_model: Model,
+        model_parallel_size: int,
+        builder_fn: Callable,
+        builder_params: list[Any],
+        formatter: Llama3ChatFormat | Llama4ChatFormat,
     ):
-        self.config = config
-        self.model_id = model_id
-        self.llama_model = llama_model
-
-        # this is a hack because Agent's loop uses this to tokenize and check if input is too long
-        # while the tool-use loop is going
-        resolved_model = resolve_model(model_id)
-        if resolved_model is None:
-            # if the model is not a native llama model, get the default checkpoint_dir based on model id
-            checkpoint_dir = model_checkpoint_dir(model_id)
-        else:
-            # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
-            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
-        tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
-        self.formatter = ChatFormat(Tokenizer(tokenizer_path))
+        self.model_parallel_size = model_parallel_size
+        self.builder_fn = builder_fn
+        self.builder_params = builder_params
+        self.formatter = formatter
 
     def start(self):
         self.__enter__()
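The constructor no longer resolves models or loads a tokenizer itself; the caller now supplies the parallelism degree, a builder, its arguments, and a ready-made formatter. A rough sketch of a call site after this change follows: only the four keyword arguments come from the hunk above, while the builder function, checkpoint path, and sequence length are assumed placeholders (the formatter construction mirrors the ChatFormat(Tokenizer(...)) pattern removed above).

import os

from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

# LlamaModelParallelGenerator would be imported from the module patched above
# (module path elided in this diff view).


def build_llama(checkpoint_dir: str, max_seq_len: int):
    # Placeholder builder: in practice a Llama3- or Llama4-specific build
    # function (the old code called Llama3.build directly).
    raise NotImplementedError


checkpoint_dir = "/path/to/checkpoint"  # placeholder path
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")

generator = LlamaModelParallelGenerator(
    model_parallel_size=8,  # e.g. the number of checkpoint shards
    builder_fn=build_llama,
    builder_params=[checkpoint_dir, 8192],
    formatter=Llama3ChatFormat(Tokenizer(tokenizer_path)),
)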
@@ -87,11 +70,9 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
-        model_parallel_size = self.llama_model.pth_file_count
-
         self.group = ModelParallelProcessGroup(
-            model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
+            self.model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
         )
         self.group.start()
         return self
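functools.partial binds the builder and its parameters up front, so the callback handed to ModelParallelProcessGroup can later be invoked without arguments inside each worker. A small, self-contained illustration of that deferral pattern, with stand-in names only:

from functools import partial


def build_model(model_parallel_size: int, checkpoint_dir: str) -> str:
    # Stand-in for the deferred model build performed inside each worker.
    return f"built from {checkpoint_dir} across {model_parallel_size} ranks"


# __enter__ binds the arguments now, without building anything yet...
bound_cb = partial(build_model, 8, "/path/to/checkpoint")

# ...and the zero-argument callback is invoked later, when a worker starts.
print(bound_cb())  # built from /path/to/checkpoint across 8 ranks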