forked from phoenix-oss/llama-stack-mirror
feat: introduce llama4 support (#1877)
As the title says: adds Llama 4 support. Details are in the README and elsewhere.
This commit is contained in:
parent 23a99a4b22 · commit b8f1561956

61 changed files with 205,222 additions and 6,439 deletions
@@ -34,11 +34,16 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
    ModelFamily,
    SamplingParams,
    StopReason,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -55,7 +60,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import MetaReferenceInferenceConfig
-from .llama3.generation import Llama3
+from .generators import Llama3Generator, Llama4Generator
from .model_parallel import LlamaModelParallelGenerator

log = logging.getLogger(__name__)
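
The new .generators module is not shown in this diff, but the shape that Llama3Generator and Llama4Generator appear to share can be inferred from how self.generator is used later in this file. A minimal sketch, assuming only the attributes and methods actually called below; every other name here is hypothetical:

from typing import Any, Iterator, Protocol


class TokenResult(Protocol):
    token: int     # token id, later compared against tokenizer.eot_id / eom_id
    text: str      # decoded text for this token
    logprobs: Any  # populated when request.logprobs is set


class LlamaGenerator(Protocol):
    # .formatter exposes a .tokenizer with eot_id, eom_id, stop_tokens, decode()
    formatter: Any

    def completion(self, request: Any) -> Iterator[TokenResult]: ...
    def chat_completion(self, request: Any) -> Iterator[TokenResult]: ...

In the non-distributed code path the generator object returned by a builder function is used directly through this interface; start()/stop() are only called on the LlamaModelParallelGenerator wrapper.
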
@@ -64,6 +69,14 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1)


+def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
+    return Llama3Generator(config, model_id, llama_model)
+
+
+def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
+    return Llama4Generator(config, model_id, llama_model)
+
+
class MetaReferenceInferenceImpl(
    SentenceTransformerEmbeddingMixin,
    Inference,
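
One plausible reason the builders are plain module-level functions (an assumption, not stated in the diff) is that they can be pickled and handed to worker processes by the model-parallel generator, which a lambda or closure could not be. A small sketch illustrating that property:

import pickle


def example_builder_fn(config, model_id, llama_model):
    # Stand-in for Llama3Generator(config, model_id, llama_model).
    return (config, model_id, llama_model)


# A top-level function survives a pickle round-trip by reference.
assert pickle.loads(pickle.dumps(example_builder_fn)) is example_builder_fn

# A lambda does not, which is why a closure would not work as a builder here.
try:
    pickle.dumps(lambda config, model_id, llama_model: None)
except Exception as exc:  # pickle.PicklingError, depending on Python version
    print(f"lambda is not picklable: {exc}")
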
@@ -77,29 +90,10 @@ class MetaReferenceInferenceImpl(
    async def initialize(self) -> None:
        pass

-    async def load_model(self, model_id, llama_model) -> None:
-        log.info(f"Loading model `{model_id}`")
-        if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
-            self.generator.start()
-        else:
-            self.generator = Llama3.build(self.config, model_id, llama_model)
-
-        self.model_id = model_id
-        self.llama_model = llama_model
-
    async def shutdown(self) -> None:
        if self.config.create_distributed_process_group:
            self.generator.stop()

-    def check_model(self, request) -> None:
-        if self.model_id is None or self.llama_model is None:
-            raise RuntimeError(
-                "No avaible model yet, please register your requested model or add your model in the resouces first"
-            )
-        elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
-
    async def unregister_model(self, model_id: str) -> None:
        pass

@@ -127,11 +121,57 @@ class MetaReferenceInferenceImpl(
        if model.model_type == ModelType.embedding:
            self._load_sentence_transformer_model(model.provider_resource_id)

        # TODO: what is this?! you can't really specify skipping via model metadata
        # kill this madness
        if "skip_load" in model.metadata and model.metadata["skip_load"]:
            return model

        await self.load_model(model.identifier, llama_model)
        return model

    async def load_model(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")

        if llama_model.model_family in {
            ModelFamily.llama3,
            ModelFamily.llama3_1,
            ModelFamily.llama3_2,
            ModelFamily.llama3_3,
        }:
            builder_fn = llama3_builder_fn
        elif llama_model.model_family == ModelFamily.llama4:
            builder_fn = llama4_builder_fn
        else:
            raise ValueError(f"Unsupported model family: {llama_model.model_family}")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                model_parallel_size=llama_model.pth_file_count,
                builder_fn=builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
                    if llama_model.model_family == ModelFamily.llama4
                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
                ),
            )
            self.generator.start()
        else:
            self.generator = builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
                "No available model yet, please register your requested model or add your model in the resources first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def completion(
        self,
        model_id: str,
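
A simplified, self-contained sketch of the family-to-builder dispatch in load_model above, using a stand-in ModelFamily enum and string placeholders instead of the real llama_stack datatypes and builder functions:

from enum import Enum


class ModelFamily(Enum):
    llama3 = "llama3"
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    llama4 = "llama4"


LLAMA3_FAMILIES = {
    ModelFamily.llama3,
    ModelFamily.llama3_1,
    ModelFamily.llama3_2,
    ModelFamily.llama3_3,
}


def select_builder(family: ModelFamily) -> str:
    # Mirrors load_model: all Llama 3.x families share one generator builder,
    # Llama 4 gets its own, and anything else is rejected.
    if family in LLAMA3_FAMILIES:
        return "llama3_builder_fn"
    if family == ModelFamily.llama4:
        return "llama4_builder_fn"
    raise ValueError(f"Unsupported model family: {family}")


assert select_builder(ModelFamily.llama3_3) == "llama3_builder_fn"
assert select_builder(ModelFamily.llama4) == "llama4_builder_fn"

In the real method, the chosen builder is either called directly (single process) or passed as builder_fn/builder_params to LlamaModelParallelGenerator, along with a formatter matched to the model family.
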
@@ -164,14 +204,16 @@ class MetaReferenceInferenceImpl(
        return await self._nonstream_completion(request)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        tokenizer = self.generator.formatter.tokenizer
+
        def impl():
            stop_reason = None

            for token_result in self.generator.completion(request):
-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
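
This hunk (and the matching ones below) switches stop detection from comparing decoded text against "<|eot_id|>" / "<|eom_id|>" to comparing token ids against the tokenizer's eot_id / eom_id, so the check no longer depends on how a particular tokenizer renders its special tokens. A minimal sketch with a stand-in tokenizer and illustrative ids (not the real vocabulary values):

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class StopReason(Enum):
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


@dataclass
class FakeTokenizer:
    eot_id: int = 9001  # illustrative ids only
    eom_id: int = 9002


def stop_reason_for(token: int, tokenizer: FakeTokenizer) -> Optional[StopReason]:
    # Compare token ids instead of decoded text.
    if token == tokenizer.eot_id:
        return StopReason.end_of_turn
    if token == tokenizer.eom_id:
        return StopReason.end_of_message
    return None


tok = FakeTokenizer()
assert stop_reason_for(tok.eot_id, tok) is StopReason.end_of_turn
assert stop_reason_for(12345, tok) is None
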
@@ -205,6 +247,8 @@ class MetaReferenceInferenceImpl(
                yield x

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        tokenizer = self.generator.formatter.tokenizer
+
        def impl():
            tokens = []
            logprobs = []
@@ -212,9 +256,9 @@ class MetaReferenceInferenceImpl(

            for token_result in self.generator.completion(request):
                tokens.append(token_result.token)
-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message

                if request.logprobs:
@@ -225,11 +269,9 @@ class MetaReferenceInferenceImpl(
            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

+            if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+                tokens = tokens[:-1]
            content = self.generator.formatter.tokenizer.decode(tokens)
-            if content.endswith("<|eot_id|>"):
-                content = content[: -len("<|eot_id|>")]
-            elif content.endswith("<|eom_id|>"):
-                content = content[: -len("<|eom_id|>")]
            return CompletionResponse(
                content=content,
                stop_reason=stop_reason,
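
Here the trailing stop token is dropped from the token list before decoding, replacing the old string-suffix stripping of "<|eot_id|>" / "<|eom_id|>" from the decoded text. A toy sketch of the idea, with a made-up vocabulary and stop-token set standing in for the real tokenizer:

STOP_TOKENS = {8, 9}  # toy ids standing in for eot/eom
VOCAB = {1: "Hello", 2: " world", 8: "<|eot_id|>", 9: "<|eom_id|>"}


def decode(tokens: list[int]) -> str:
    return "".join(VOCAB[t] for t in tokens)


def decode_without_stop(tokens: list[int]) -> str:
    # Same idea as the tokens = tokens[:-1] trimming above: remove the stop
    # token id first, then decode, instead of stripping its text afterwards.
    if tokens and tokens[-1] in STOP_TOKENS:
        tokens = tokens[:-1]
    return decode(tokens)


assert decode_without_stop([1, 2, 8]) == "Hello world"
assert decode_without_stop([1, 2]) == "Hello world"
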
@@ -288,6 +330,8 @@ class MetaReferenceInferenceImpl(
        return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        tokenizer = self.generator.formatter.tokenizer
+
        def impl():
            tokens = []
            logprobs = []
@@ -296,9 +340,9 @@ class MetaReferenceInferenceImpl(
            for token_result in self.generator.chat_completion(request):
                tokens.append(token_result.token)

-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message

                if request.logprobs:
@@ -326,6 +370,8 @@ class MetaReferenceInferenceImpl(
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        tokenizer = self.generator.formatter.tokenizer
+
        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
@@ -355,10 +401,10 @@ class MetaReferenceInferenceImpl(
                    )
                    continue

-                if token_result.text == "<|eot_id|>":
+                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
-                elif token_result.text == "<|eom_id|>":
+                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else: