diff --git a/llama_stack/providers/inline/inference/meta_reference/common.py b/llama_stack/providers/inline/inference/meta_reference/common.py
new file mode 100644
index 000000000..3dc5e89f9
--- /dev/null
+++ b/llama_stack/providers/inline/inference/meta_reference/common.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+
+class TokenResult(BaseModel):
+    token: int
+    text: str
+    logprobs: Optional[List[float]] = None
+
+
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
+
+    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
+    if not any(p.exists() for p in paths):
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), (
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"If you are trying to use the native llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}."
+    )
+    return str(checkpoint_dir)
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index cd7bcdd22..062bf215e 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 
 from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .llama3.generation import Llama3
 from .model_parallel import LlamaModelParallelGenerator
 
 log = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
             self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config, model_id, llama_model)
+            self.generator = Llama3.build(self.config, model_id, llama_model)
 
         self.model_id = model_id
         self.llama_model = llama_model
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py
similarity index 93%
rename from llama_stack/providers/inline/inference/meta_reference/generation.py
rename to llama_stack/providers/inline/inference/meta_reference/llama3/generation.py
index f76b5a448..206ee4f7b 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py
@@ -24,7 +24,6 @@ from fairscale.nn.model_parallel.initialize import (
     model_parallel_is_initialized,
 )
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-from pydantic import BaseModel
 
 from llama_stack.apis.inference import (
     Fp8QuantizationConfig,
@@ -32,7 +31,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     Model,
@@ -47,36 +45,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     CompletionRequestWithRawContent,
 )
 
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-from .llama3.args import ModelArgs
-from .llama3.model import Transformer
-from .llama3.multimodal.model import CrossAttentionTransformer
+from ..common import TokenResult, model_checkpoint_dir
+from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .args import ModelArgs
+from .model import Transformer
+from .multimodal.model import CrossAttentionTransformer
 
 log = logging.getLogger(__name__)
 
 
-def model_checkpoint_dir(model_id) -> str:
-    checkpoint_dir = Path(model_local_dir(model_id))
-
-    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
-    if not any(p.exists() for p in paths):
-        checkpoint_dir = checkpoint_dir / "original"
-
-    assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
-        f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
-    )
-    return str(checkpoint_dir)
-
-
-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
-
-
-class Llama:
+class Llama3:
     @staticmethod
     def build(
         config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
@@ -168,7 +146,7 @@ class Llama:
 
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                from .quantization.loader import convert_to_fp8_quantized_model
+                from ..quantization.loader import convert_to_fp8_quantized_model
 
                 # load on CPU in bf16 so that fp8 conversion does not find an
                 # unexpected (fp32, e.g.) datatype
@@ -181,7 +159,7 @@ class Llama:
                 model.load_state_dict(state_dict, strict=False)
                 model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                from .quantization.loader import convert_to_int4_quantized_model
+                from ..quantization.loader import convert_to_int4_quantized_model
 
                 model = Transformer(model_args)
                 model = convert_to_int4_quantized_model(model, model_args, config)
@@ -191,7 +169,7 @@ class Llama:
                 # Add a wrapper for adding hadamard transform for spinquant.
                 # This needs to be done after loading the state dict otherwise an error will be raised while
                 # loading the state dict.
-                from .quantization.hadamard_utils import (
+                from ..quantization.hadamard_utils import (
                     add_hadamard_transform_for_spinquant,
                 )
 
@@ -220,7 +198,7 @@ class Llama:
             model.to(device)
 
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args, llama_model_id)
+        return Llama3(model, tokenizer, model_args, llama_model_id)
 
     def __init__(
         self,
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 91d0445ab..738f9ddcd 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     CompletionRequestWithRawContent,
 )
 
-from .generation import TokenResult
+from .common import TokenResult
 
 log = logging.getLogger(__name__)
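
For reviewers, a short usage sketch of the reorganized modules. The import paths and the `Llama3.build(config, model_id, llama_model)` signature come from this diff; the token values and the model id in the comments are illustrative assumptions, not part of the change:

```python
# Sketch only: shows the new homes of TokenResult, model_checkpoint_dir, and Llama3.
from llama_stack.providers.inline.inference.meta_reference.common import (
    TokenResult,
    model_checkpoint_dir,
)

# TokenResult now lives in common.py, so parallel_utils.py no longer imports it
# from generation.py. Field values here are illustrative.
chunk = TokenResult(token=128000, text="<|begin_of_text|>")

# model_checkpoint_dir() asserts that a local checkpoint exists, so call it only after
# `llama download --model-id <model-id>` (or after saving a checkpoint there manually).
# ckpt_dir = model_checkpoint_dir("Llama3.1-8B-Instruct")  # hypothetical model id

# The generator class is now Llama3 under the llama3/ subpackage; MetaReferenceInferenceImpl
# calls Llama3.build(config, model_id, llama_model) when model parallelism is not used.
# from llama_stack.providers.inline.inference.meta_reference.llama3.generation import Llama3
# generator = Llama3.build(config, model_id, llama_model)
```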