refactor: move generation.py to llama3

Ashwin Bharambe 2025-03-03 13:38:06 -08:00
parent 725423c95c
commit 02066591b8
4 changed files with 46 additions and 35 deletions


@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel

from llama_stack.distribution.utils.model_utils import model_local_dir


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


def model_checkpoint_dir(model_id) -> str:
    checkpoint_dir = Path(model_local_dir(model_id))

    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in paths):
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you are trying to use the native Llama model, please download it using `llama download --model-id {model_id}`. "
        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}."
    )
    return str(checkpoint_dir)
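
For context, a minimal usage sketch of the relocated helpers (not part of the commit): the model id is illustrative, and it assumes TokenResult and model_checkpoint_dir are importable from wherever this new common module lives in the provider package.

# Hypothetical usage of the helpers above; requires the model to have been
# fetched beforehand with `llama download`, otherwise the assert fires.
ckpt_dir = model_checkpoint_dir("Llama3.1-8B-Instruct")  # illustrative model id
print(f"Resolved checkpoint directory: {ckpt_dir}")

# TokenResult is the per-token payload yielded while streaming a generation.
tok = TokenResult(token=791, text="The", logprobs=[-0.12])
print(tok.token, tok.text, tok.logprobs)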


@@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .llama3.generation import Llama3
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
self.generator = Llama3.build(self.config, model_id, llama_model)
self.model_id = model_id
self.llama_model = llama_model


@@ -24,7 +24,6 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
@@ -32,7 +31,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
@@ -47,36 +45,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .llama3.args import ModelArgs
from .llama3.model import Transformer
from .llama3.multimodal.model import CrossAttentionTransformer
from ..common import TokenResult, model_checkpoint_dir
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .args import ModelArgs
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
log = logging.getLogger(__name__)
def model_checkpoint_dir(model_id) -> str:
    checkpoint_dir = Path(model_local_dir(model_id))

    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in paths):
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you are trying to use the native Llama model, please download it using `llama download --model-id {model_id}`. "
        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}."
    )
    return str(checkpoint_dir)


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None
class Llama:
class Llama3:
@staticmethod
def build(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
@@ -168,7 +146,7 @@ class Llama:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
from .quantization.loader import convert_to_fp8_quantized_model
from ..quantization.loader import convert_to_fp8_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
@@ -181,7 +159,7 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from .quantization.loader import convert_to_int4_quantized_model
from ..quantization.loader import convert_to_int4_quantized_model
model = Transformer(model_args)
model = convert_to_int4_quantized_model(model, model_args, config)
@@ -191,7 +169,7 @@ class Llama:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
from .quantization.hadamard_utils import (
from ..quantization.hadamard_utils import (
add_hadamard_transform_for_spinquant,
)
@@ -220,7 +198,7 @@ class Llama:
model.to(device)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model_id)
return Llama3(model, tokenizer, model_args, llama_model_id)
def __init__(
self,


@@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
from .generation import TokenResult
from .common import TokenResult
log = logging.getLogger(__name__)