diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py
index 20c4e5e58..7a4087c8f 100644
--- a/llama_stack/models/llama/llama4/generation.py
+++ b/llama_stack/models/llama/llama4/generation.py
@@ -135,7 +135,7 @@ class Llama4:
         if print_model_input:
             cprint("Input to model:\n", "yellow")
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), "grey")

         prompt_tokens = [inp.tokens for inp in llm_inputs]
         bsz = len(llm_inputs)
diff --git a/llama_stack/providers/inline/inference/meta_reference/common.py b/llama_stack/providers/inline/inference/meta_reference/common.py
index 3dc5e89f9..beb0d39d4 100644
--- a/llama_stack/providers/inline/inference/meta_reference/common.py
+++ b/llama_stack/providers/inline/inference/meta_reference/common.py
@@ -5,19 +5,10 @@
 # the root directory of this source tree.

 from pathlib import Path
-from typing import List, Optional
-
-from pydantic import BaseModel

 from llama_stack.distribution.utils.model_utils import model_local_dir


-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
-
-
 def model_checkpoint_dir(model_id) -> str:
     checkpoint_dir = Path(model_local_dir(model_id))

diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 7d089effc..315667506 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -51,7 +51,7 @@ class MetaReferenceInferenceConfig(BaseModel):
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
-        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:null}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py
index b820dcbd8..c2baed905 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -113,6 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)


+# TODO: combine Llama3 and Llama4 generators since they are almost identical now
 class Llama4Generator:
     def __init__(
         self,
@@ -165,8 +166,8 @@ class Llama4Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -177,7 +178,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

     def chat_completion(
         self,
@@ -189,8 +191,8 @@ class Llama4Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -201,7 +203,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]


 class Llama3Generator:
@@ -255,8 +258,8 @@ class Llama3Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -267,7 +270,8 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

     def chat_completion(
         self,
@@ -279,8 +283,8 @@ class Llama3Generator:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -291,4 +295,5 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index ca2f51ac7..5f81d6421 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -149,7 +149,7 @@ class MetaReferenceInferenceImpl(

         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                 builder_fn=builder_fn,
                 builder_params=builder_params,
                 formatter=(
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index e8767c2ff..74fc49d5e 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -32,13 +32,12 @@
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated

+from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )

-from .common import TokenResult
-
 log = logging.getLogger(__name__)

@@ -75,7 +74,7 @@ class TaskRequest(BaseModel):

 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: TokenResult
+    result: GenerationResult


 class ExceptionResponse(BaseModel):
diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
index 8c7bcbc3c..9f97158f8 100644
--- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
@@ -20,7 +20,7 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -32,7 +32,7 @@ providers:
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index e6c143363..eda332123 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -20,7 +20,7 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
-      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:null}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
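
Note on the generator change above: Llama4Generator and Llama3Generator no longer `yield from` a single-input generate(); the inner generator now takes a batch via `llm_inputs=[...]` and yields, at each decode step, a list with one result per sequence, so the single-request paths wrap the request in a one-element batch and yield `result[0]`. The sketch below is a minimal illustration of that wrapper pattern only; `StepResult` and `toy_batched_generate` are hypothetical stand-ins (the real per-step type is `GenerationResult` from `llama_stack.models.llama.datatypes`, whose fields are not shown in this diff; `StepResult` just mirrors the removed `TokenResult`).

from typing import Generator, List, Optional

from pydantic import BaseModel


class StepResult(BaseModel):
    # Illustration only: mirrors the fields of the removed TokenResult.
    token: int
    text: str
    logprobs: Optional[List[float]] = None


def toy_batched_generate(llm_inputs: List[List[int]]) -> Generator[List[StepResult], None, None]:
    # Stand-in for the batched inner generator: each step yields one result per sequence.
    for step in range(3):
        yield [StepResult(token=step, text=f"tok{step}") for _ in llm_inputs]


def completion_stream(prompt_tokens: List[int]) -> Generator[StepResult, None, None]:
    # Same shape as the diff: wrap the single request in a one-element batch,
    # then unwrap each step's list by yielding result[0].
    for result in toy_batched_generate(llm_inputs=[prompt_tokens]):
        yield result[0]


if __name__ == "__main__":
    for r in completion_stream([1, 2, 3]):
        print(r.token, r.text)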