refactor: move generation.py to llama3

Ashwin Bharambe 2025-03-03 13:38:06 -08:00
parent 725423c95c
commit 02066591b8
4 changed files with 46 additions and 35 deletions


@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel

from llama_stack.distribution.utils.model_utils import model_local_dir


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


def model_checkpoint_dir(model_id) -> str:
    checkpoint_dir = Path(model_local_dir(model_id))

    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in paths):
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
        f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
    )
    return str(checkpoint_dir)
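The moved helper's behavior is worth spelling out: it resolves the locally downloaded model directory via model_local_dir, prefers a directory that already contains consolidated.pth / consolidated.00.pth, and otherwise falls back to an original/ subdirectory before asserting. A standalone sketch of that fallback logic (pathlib only; resolve_checkpoint_dir and the temporary layout are illustrative, not part of the commit):

from pathlib import Path
import tempfile

def resolve_checkpoint_dir(model_dir: Path) -> Path:
    # Look for consolidated checkpoint shards directly in the model directory ...
    candidates = [model_dir / f"consolidated.{ext}" for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in candidates):
        # ... otherwise fall back to the "original/" subdirectory, as the helper above does.
        model_dir = model_dir / "original"
    if not model_dir.exists():
        raise FileNotFoundError(f"Could not find checkpoints in: {model_dir}")
    return model_dir

# Runnable demonstration of the fallback with a temporary layout (no real checkpoints needed):
base = Path(tempfile.mkdtemp())
(base / "original").mkdir()
print(resolve_checkpoint_dir(base))  # prints .../original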


@@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .llama3.generation import Llama3
from .model_parallel import LlamaModelParallelGenerator

log = logging.getLogger(__name__)

@@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
            self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
            self.generator.start()
        else:
-            self.generator = Llama.build(self.config, model_id, llama_model)
+            self.generator = Llama3.build(self.config, model_id, llama_model)
        self.model_id = model_id
        self.llama_model = llama_model
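The net effect in this file is the renamed entry point for the single-process path. A hedged sketch of the surrounding selection logic (the branch condition sits outside the hunk, so the flag below is a stand-in, and config construction is elided):

# Sketch only: use_model_parallel is a stand-in for the real check, which this hunk does not show.
if use_model_parallel:
    generator = LlamaModelParallelGenerator(config, model_id, llama_model)
    generator.start()
else:
    generator = Llama3.build(config, model_id, llama_model)  # previously Llama.build(...)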


@@ -24,7 +24,6 @@ from fairscale.nn.model_parallel.initialize import (
    model_parallel_is_initialized,
)

from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-from pydantic import BaseModel

from llama_stack.apis.inference import (
    Fp8QuantizationConfig,
@@ -32,7 +31,6 @@ from llama_stack.apis.inference import (
    ResponseFormat,
    ResponseFormatType,
)
-from llama_stack.distribution.utils.model_utils import model_local_dir

from llama_stack.models.llama.datatypes import (
    GreedySamplingStrategy,
    Model,
@@ -47,36 +45,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    CompletionRequestWithRawContent,
)

-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-from .llama3.args import ModelArgs
-from .llama3.model import Transformer
-from .llama3.multimodal.model import CrossAttentionTransformer
+from ..common import TokenResult, model_checkpoint_dir
+from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .args import ModelArgs
+from .model import Transformer
+from .multimodal.model import CrossAttentionTransformer

log = logging.getLogger(__name__)

-def model_checkpoint_dir(model_id) -> str:
-    checkpoint_dir = Path(model_local_dir(model_id))
-
-    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
-    if not any(p.exists() for p in paths):
-        checkpoint_dir = checkpoint_dir / "original"
-
-    assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
-        f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
-    )
-    return str(checkpoint_dir)
-
-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
-
-class Llama:
+class Llama3:
    @staticmethod
    def build(
        config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
@@ -168,7 +146,7 @@ class Llama:
        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
            if isinstance(config.quantization, Fp8QuantizationConfig):
-                from .quantization.loader import convert_to_fp8_quantized_model
+                from ..quantization.loader import convert_to_fp8_quantized_model

                # load on CPU in bf16 so that fp8 conversion does not find an
                # unexpected (fp32, e.g.) datatype
@@ -181,7 +159,7 @@ class Llama:
                model.load_state_dict(state_dict, strict=False)
                model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
            elif isinstance(config.quantization, Int4QuantizationConfig):
-                from .quantization.loader import convert_to_int4_quantized_model
+                from ..quantization.loader import convert_to_int4_quantized_model

                model = Transformer(model_args)
                model = convert_to_int4_quantized_model(model, model_args, config)
@@ -191,7 +169,7 @@ class Llama:
                # Add a wrapper for adding hadamard transform for spinquant.
                # This needs to be done after loading the state dict otherwise an error will be raised while
                # loading the state dict.
-                from .quantization.hadamard_utils import (
+                from ..quantization.hadamard_utils import (
                    add_hadamard_transform_for_spinquant,
                )
@@ -220,7 +198,7 @@ class Llama:
        model.to(device)

        log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args, llama_model_id)
+        return Llama3(model, tokenizer, model_args, llama_model_id)

    def __init__(
        self,
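Taken together, the rewritten imports describe the new layout: shared pieces move up to the parent provider package (..common, ..config, ..quantization), while the Llama-3-specific model code becomes sibling modules of the relocated generation.py (.args, .model, .multimodal.model). A sketch of the implied tree (only the llama3/ directory and the module names appear in the diff; the parent package name is inferred):

# <meta-reference provider package>/      (name inferred, not shown in this diff)
#     common.py            <- "..common"        (TokenResult, model_checkpoint_dir; the new file above)
#     config.py            <- "..config"        (MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig)
#     quantization/        <- "..quantization"  (loader, hadamard_utils)
#     model_parallel.py    (imports ".common"; final hunk below)
#     llama3/
#         generation.py    (this file; class Llama renamed to Llama3)
#         args.py          <- ".args"
#         model.py         <- ".model"
#         multimodal/
#             model.py     <- ".multimodal.model"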


@@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    CompletionRequestWithRawContent,
)

-from .generation import TokenResult
+from .common import TokenResult

log = logging.getLogger(__name__)
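With TokenResult now living in the shared common module, the in-process generator and the model-parallel wrapper stream the same chunk type. A self-contained sketch of a consumer of such a stream (the class mirrors the definition added in this commit; the stream itself is faked and is not the llama_stack API):

from typing import Iterable, List, Optional

from pydantic import BaseModel

class TokenResult(BaseModel):
    # Mirrors the model added in the new common module above.
    token: int
    text: str
    logprobs: Optional[List[float]] = None

def collect_text(stream: Iterable[TokenResult]) -> str:
    # Concatenate the decoded text of each streamed token into the final completion.
    return "".join(chunk.text for chunk in stream)

# Faked stream for illustration only:
print(collect_text([TokenResult(token=1, text="Hello"), TokenResult(token=2, text=" world")]))  # Hello world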