feat: introduce llama4 support (#1877)

As the title says. Details are in the README and elsewhere.
Ashwin Bharambe · 2025-04-05 11:53:35 -07:00 · committed by GitHub
parent 23a99a4b22
commit b8f1561956
61 changed files with 205222 additions and 6439 deletions


@@ -34,11 +34,16 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import (
+    ModelFamily,
     SamplingParams,
     StopReason,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
+from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
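Note: both tokenizer/chat-format stacks are now imported side by side under aliases, and `load_model()` further down picks the right formatter per model family. A minimal sketch of that selection as a standalone helper (`formatter_for` is hypothetical; the commit does this inline):

```python
from llama_stack.models.llama.datatypes import ModelFamily
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer


def formatter_for(family: ModelFamily):
    # Hypothetical helper mirroring the inline selection in load_model():
    # Llama 4 gets its own tokenizer/chat format, every other family Llama 3.
    if family == ModelFamily.llama4:
        return Llama4ChatFormat(Llama4Tokenizer.get_instance())
    return Llama3ChatFormat(Llama3Tokenizer.get_instance())
```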
@@ -55,7 +60,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from .config import MetaReferenceInferenceConfig
-from .llama3.generation import Llama3
+from .generators import Llama3Generator, Llama4Generator
 from .model_parallel import LlamaModelParallelGenerator

 log = logging.getLogger(__name__)
@@ -64,6 +69,14 @@ log = logging.getLogger(__name__)
 SEMAPHORE = asyncio.Semaphore(1)


+def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
+    return Llama3Generator(config, model_id, llama_model)
+
+
+def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
+    return Llama4Generator(config, model_id, llama_model)
+
+
 class MetaReferenceInferenceImpl(
     SentenceTransformerEmbeddingMixin,
     Inference,
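Note: the new `.generators` module replaces the old direct `Llama3.build()` call with generator classes sharing one surface, but the diff never shows their definition. A sketch of the interface this file appears to rely on, inferred purely from the call sites below (an assumption, not the actual classes):

```python
from typing import Any, Iterator, Protocol


class LlamaGenerator(Protocol):
    # Inferred interface: Llama3Generator and Llama4Generator are both used
    # through exactly these members in this file.
    formatter: Any  # has a .tokenizer exposing eot_id, eom_id, stop_tokens, decode()

    def completion(self, request: Any) -> Iterator[Any]: ...
    def chat_completion(self, request: Any) -> Iterator[Any]: ...
```

A plausible reason for module-level builder functions rather than bound methods is that plain functions pickle cleanly when the model-parallel generator spawns worker processes.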
@@ -77,29 +90,10 @@ class MetaReferenceInferenceImpl(
     async def initialize(self) -> None:
         pass

-    async def load_model(self, model_id, llama_model) -> None:
-        log.info(f"Loading model `{model_id}`")
-        if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
-            self.generator.start()
-        else:
-            self.generator = Llama3.build(self.config, model_id, llama_model)
-
-        self.model_id = model_id
-        self.llama_model = llama_model
-
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    def check_model(self, request) -> None:
-        if self.model_id is None or self.llama_model is None:
-            raise RuntimeError(
-                "No available model yet; please register your requested model or add your model to the resources first"
-            )
-        elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
-
     async def unregister_model(self, model_id: str) -> None:
         pass
@@ -127,11 +121,57 @@ class MetaReferenceInferenceImpl(
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)

+        # TODO: what is this?! you can't really specify skipping via model metadata
+        # kill this madness
         if "skip_load" in model.metadata and model.metadata["skip_load"]:
             return model
+
         await self.load_model(model.identifier, llama_model)
         return model

+    async def load_model(self, model_id, llama_model) -> None:
+        log.info(f"Loading model `{model_id}`")
+        if llama_model.model_family in {
+            ModelFamily.llama3,
+            ModelFamily.llama3_1,
+            ModelFamily.llama3_2,
+            ModelFamily.llama3_3,
+        }:
+            builder_fn = llama3_builder_fn
+        elif llama_model.model_family == ModelFamily.llama4:
+            builder_fn = llama4_builder_fn
+        else:
+            raise ValueError(f"Unsupported model family: {llama_model.model_family}")
+
+        builder_params = [self.config, model_id, llama_model]
+
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(
+                model_parallel_size=llama_model.pth_file_count,
+                builder_fn=builder_fn,
+                builder_params=builder_params,
+                formatter=(
+                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
+                    if llama_model.model_family == ModelFamily.llama4
+                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
+                ),
+            )
+            self.generator.start()
+        else:
+            self.generator = builder_fn(*builder_params)
+
+        self.model_id = model_id
+        self.llama_model = llama_model
+
+    def check_model(self, request) -> None:
+        if self.model_id is None or self.llama_model is None:
+            raise RuntimeError(
+                "No available model yet; please register your requested model or add your model to the resources first"
+            )
+        elif request.model != self.model_id:
+            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
+
     async def completion(
         self,
         model_id: str,
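Note: `load_model()` dispatches on model family with an if/elif chain and raises on anything it does not recognize; `model_parallel_size` now comes from `llama_model.pth_file_count`, so the process-group size follows the number of checkpoint shards. The same mapping expressed as a table, reusing the `llama3_builder_fn`/`llama4_builder_fn` defined earlier in this file (a sketch, not this commit's code; `builder_for` is hypothetical):

```python
from llama_stack.models.llama.datatypes import ModelFamily

# Sketch: the dispatch in load_model() above, as an explicit table.
_BUILDER_FNS = {
    ModelFamily.llama3: llama3_builder_fn,
    ModelFamily.llama3_1: llama3_builder_fn,
    ModelFamily.llama3_2: llama3_builder_fn,
    ModelFamily.llama3_3: llama3_builder_fn,
    ModelFamily.llama4: llama4_builder_fn,
}


def builder_for(family: ModelFamily):
    # Table-driven equivalent of the if/elif chain, same error contract.
    try:
        return _BUILDER_FNS[family]
    except KeyError:
        raise ValueError(f"Unsupported model family: {family}") from None
```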
@@ -164,14 +204,16 @@ class MetaReferenceInferenceImpl(
             return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        tokenizer = self.generator.formatter.tokenizer
+
         def impl():
             stop_reason = None

             for token_result in self.generator.completion(request):
-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                     stop_reason = StopReason.end_of_message
                     text = ""
                 else:
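Note: the stop check moves from decoded text (`"<|eot_id|>"` / `"<|eom_id|>"`) to token ids, which keeps the logic tokenizer-agnostic: the Llama 4 tokenizer need not render its stop tokens as the same strings, but both tokenizers expose `eot_id` and `eom_id`. The same pair of comparisons recurs in all four generation paths; a sketch of a shared helper (`stop_reason_for` is hypothetical, not part of this commit):

```python
from typing import Optional

from llama_stack.models.llama.datatypes import StopReason


def stop_reason_for(token_id: int, tokenizer) -> Optional[StopReason]:
    # Hypothetical helper: classify a generated token as a stop signal.
    if token_id == tokenizer.eot_id:  # end of turn
        return StopReason.end_of_turn
    if token_id == tokenizer.eom_id:  # end of message (tool response expected)
        return StopReason.end_of_message
    return None
```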
@@ -205,6 +247,8 @@ class MetaReferenceInferenceImpl(
                     yield x

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        tokenizer = self.generator.formatter.tokenizer
+
         def impl():
             tokens = []
             logprobs = []
@@ -212,9 +256,9 @@ class MetaReferenceInferenceImpl(
             for token_result in self.generator.completion(request):
                 tokens.append(token_result.token)

-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                     stop_reason = StopReason.end_of_message

                 if request.logprobs:
@@ -225,11 +269,9 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens

+            if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+                tokens = tokens[:-1]
             content = self.generator.formatter.tokenizer.decode(tokens)
-            if content.endswith("<|eot_id|>"):
-                content = content[: -len("<|eot_id|>")]
-            elif content.endswith("<|eom_id|>"):
-                content = content[: -len("<|eom_id|>")]
             return CompletionResponse(
                 content=content,
                 stop_reason=stop_reason,
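Note: trailing-stop-token removal likewise moves to the token level. Instead of decoding and then stripping the hard-coded `"<|eot_id|>"`/`"<|eom_id|>"` suffixes, the code drops the final token when it is in `tokenizer.stop_tokens` and only then decodes. A minimal sketch, assuming `stop_tokens` is a container of special-token ids (`decode_without_stop_token` is hypothetical):

```python
def decode_without_stop_token(tokens: list[int], tokenizer) -> str:
    # Hypothetical helper mirroring the hunk above: drop a single trailing
    # stop token, if present, before decoding the rest.
    if tokens and tokens[-1] in tokenizer.stop_tokens:
        tokens = tokens[:-1]
    return tokenizer.decode(tokens)
```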
@@ -288,6 +330,8 @@ class MetaReferenceInferenceImpl(
             return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        tokenizer = self.generator.formatter.tokenizer
+
         def impl():
             tokens = []
             logprobs = []
@@ -296,9 +340,9 @@ class MetaReferenceInferenceImpl(
             for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)

-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                     stop_reason = StopReason.end_of_message

                 if request.logprobs:
@@ -326,6 +370,8 @@ class MetaReferenceInferenceImpl(
             return impl()

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        tokenizer = self.generator.formatter.tokenizer
+
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -355,10 +401,10 @@ class MetaReferenceInferenceImpl(
                     )
                     continue

-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                     stop_reason = StopReason.end_of_message
                     text = ""
                 else:
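Note: both streaming paths blank the stop token's text (`text = ""`), so the raw special token is never streamed to the client; only the `stop_reason` carries the information, and `end_of_message` conventionally signals that a tool response is expected. A sketch of that shared loop body as a standalone generator (a hypothetical refactor, not this commit's code):

```python
from typing import Iterable, Iterator, Optional, Tuple

from llama_stack.models.llama.datatypes import StopReason


def stream_chunks(token_results: Iterable, tokenizer) -> Iterator[Tuple[str, Optional[StopReason]]]:
    # Hypothetical refactor of the loop body shared by both streaming paths:
    # yield (text, stop_reason) pairs, suppressing stop tokens' text so the
    # special token itself never reaches the client.
    for tr in token_results:
        if tr.token == tokenizer.eot_id:
            yield "", StopReason.end_of_turn
        elif tr.token == tokenizer.eom_id:
            yield "", StopReason.end_of_message
        else:
            yield tr.text, None
```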