mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00

refactor: move generation.py to llama3

parent 725423c95c
commit 02066591b8

4 changed files with 46 additions and 35 deletions
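
For orientation, a minimal sketch (not part of the commit itself) of how the affected import sites read after this refactor; it uses only module and symbol names visible in the hunks below, and the files' full paths are not shown in this view:

# inference implementation: the generator moves into the llama3 subpackage and is renamed
from .llama3.generation import Llama3  # previously: from .generation import Llama

# model_parallel.py: the shared token type now comes from a new common module
from .common import TokenResult  # previously: from .generation import TokenResult

# llama3/generation.py: shared helpers are imported from one package level up
from ..common import TokenResult, model_checkpoint_dir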

New file (the shared "common" module referenced by the imports below):

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+
+class TokenResult(BaseModel):
+    token: int
+    text: str
+    logprobs: Optional[List[float]] = None
+
+
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
+
+    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
+    if not any(p.exists() for p in paths):
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), (
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"If you are trying to use the native llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}."
+    )
+    return str(checkpoint_dir)
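
A rough usage sketch of the two helpers introduced above, assuming they are imported from this new module (its full path is not shown in this view); the model ID below is a placeholder, not taken from the diff:

# TokenResult bundles one decoded token id, its text, and optional per-token logprobs.
result = TokenResult(token=1, text="Hello", logprobs=[-0.25])

# model_checkpoint_dir() returns the directory containing consolidated.pth / consolidated.00.pth,
# falling back to an "original" subdirectory and asserting that a checkpoint directory exists.
ckpt_dir = model_checkpoint_dir("Llama3.2-3B-Instruct")  # placeholder model ID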

Changes to the inference implementation (the module defining MetaReferenceInferenceImpl):

@@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .llama3.generation import Llama3
 from .model_parallel import LlamaModelParallelGenerator

 log = logging.getLogger(__name__)

@@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
             self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config, model_id, llama_model)
+            self.generator = Llama3.build(self.config, model_id, llama_model)

         self.model_id = model_id
         self.llama_model = llama_model

Changes to the moved generation module (now llama3/generation.py, per the commit title and the new relative imports):

@@ -24,7 +24,6 @@ from fairscale.nn.model_parallel.initialize import (
     model_parallel_is_initialized,
 )
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-from pydantic import BaseModel

 from llama_stack.apis.inference import (
     Fp8QuantizationConfig,

@@ -32,7 +31,6 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     ResponseFormatType,
 )
-from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     Model,

@@ -47,36 +45,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     CompletionRequestWithRawContent,
 )

-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-from .llama3.args import ModelArgs
-from .llama3.model import Transformer
-from .llama3.multimodal.model import CrossAttentionTransformer
+from ..common import TokenResult, model_checkpoint_dir
+from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .args import ModelArgs
+from .model import Transformer
+from .multimodal.model import CrossAttentionTransformer

 log = logging.getLogger(__name__)


-def model_checkpoint_dir(model_id) -> str:
-    checkpoint_dir = Path(model_local_dir(model_id))
-
-    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
-    if not any(p.exists() for p in paths):
-        checkpoint_dir = checkpoint_dir / "original"
-
-    assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
-        f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
-    )
-    return str(checkpoint_dir)
-
-
-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
-
-
-class Llama:
+class Llama3:
     @staticmethod
     def build(
         config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],

@@ -168,7 +146,7 @@ class Llama:

         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                from .quantization.loader import convert_to_fp8_quantized_model
+                from ..quantization.loader import convert_to_fp8_quantized_model

                 # load on CPU in bf16 so that fp8 conversion does not find an
                 # unexpected (fp32, e.g.) datatype

@@ -181,7 +159,7 @@ class Llama:
                 model.load_state_dict(state_dict, strict=False)
                 model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                from .quantization.loader import convert_to_int4_quantized_model
+                from ..quantization.loader import convert_to_int4_quantized_model

                 model = Transformer(model_args)
                 model = convert_to_int4_quantized_model(model, model_args, config)

@@ -191,7 +169,7 @@ class Llama:
                     # Add a wrapper for adding hadamard transform for spinquant.
                     # This needs to be done after loading the state dict otherwise an error will be raised while
                     # loading the state dict.
-                    from .quantization.hadamard_utils import (
+                    from ..quantization.hadamard_utils import (
                         add_hadamard_transform_for_spinquant,
                     )


@@ -220,7 +198,7 @@ class Llama:
             model.to(device)

         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args, llama_model_id)
+        return Llama3(model, tokenizer, model_args, llama_model_id)

     def __init__(
         self,

Changes to the model-parallel wrapper (presumably model_parallel.py, which previously pulled TokenResult from the old generation module):

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     CompletionRequestWithRawContent,
 )

-from .generation import TokenResult
+from .common import TokenResult

 log = logging.getLogger(__name__)
