Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)

Split off meta-reference-quantized provider
commit 1ff0476002 (parent 7ff5800dea)
10 changed files with 54 additions and 58 deletions
@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json

         from llama_stack.cli.table import print_table

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions

@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         import textwrap

         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt

@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
                 available_providers = [
-                    x for x in providers_for_api.keys() if x != "remote"
+                    x
+                    for x in providers_for_api.keys()
+                    if x not in ("remote", "remote::sample")
                 ]
                 api_provider = prompt(
                     "> Enter provider for API {}: ".format(api.value),

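Note: the only behavioral change in this file is the provider filter in the interactive build flow, which now hides the "remote::sample" stub alongside "remote". A minimal standalone sketch of the same filtering, with a hypothetical dict standing in for one API's entry in get_provider_registry():

    # Sketch only: `providers_for_api` is a made-up stand-in for what
    # get_provider_registry() returns for a single API.
    providers_for_api = {
        "meta-reference": object(),
        "meta-reference-quantized": object(),
        "remote": object(),
        "remote::sample": object(),
    }

    # Old filter: only "remote" was hidden, so "remote::sample" leaked into the prompt.
    old_choices = [x for x in providers_for_api.keys() if x != "remote"]

    # New filter: both placeholder entries are excluded from the interactive choices.
    new_choices = [
        x
        for x in providers_for_api.keys()
        if x not in ("remote", "remote::sample")
    ]

    assert "remote::sample" in old_choices and "remote::sample" not in new_choices
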
@@ -16,7 +16,7 @@ from llama_stack.apis.agents import *  # noqa: F403
 from ..agents import (
     AGENT_INSTANCES_BY_ID,
     MetaReferenceAgentsImpl,
-    MetaReferenceImplConfig,
+    MetaReferenceInferenceConfig,
 )

@@ -166,7 +166,7 @@ def mock_memory_api():
 @pytest.fixture
 async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
     impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceImplConfig(),
+        config=MetaReferenceInferenceConfig(),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
         memory_api=mock_memory_api,

@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import MetaReferenceImplConfig  # noqa
+from typing import Union
+
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


-async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+async def get_provider_impl(
+    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    _deps,
+):
     from .inference import MetaReferenceInferenceImpl

-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
     impl = MetaReferenceInferenceImpl(config)
     await impl.initialize()
     return impl

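Note: because MetaReferenceQuantizedInferenceConfig subclasses MetaReferenceInferenceConfig (see config.py below), the deleted isinstance assert would have passed for both variants anyway; the Union annotation makes the widened contract explicit instead. A self-contained sketch of the same factory pattern, assuming pydantic is installed and using stand-in classes rather than the real llama_stack types:

    import asyncio
    from typing import Union

    from pydantic import BaseModel


    class BaseConfig(BaseModel):  # stand-in for MetaReferenceInferenceConfig
        model: str = "Llama3.1-8B-Instruct"


    class QuantizedConfig(BaseConfig):  # stand-in for the quantized variant
        quantization: str = "fp8"


    class Impl:
        def __init__(self, config: BaseConfig) -> None:
            self.config = config

        async def initialize(self) -> None:
            pass  # the real implementation loads the checkpoint here


    async def get_provider_impl(config: Union[BaseConfig, QuantizedConfig], _deps):
        impl = Impl(config)
        await impl.initialize()
        return impl


    impl = asyncio.run(get_provider_impl(QuantizedConfig(), {}))
    print(type(impl.config).__name__)  # QuantizedConfig
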
@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
 from llama_stack.providers.utils.inference import supported_inference_models


-class MetaReferenceImplConfig(BaseModel):
+class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
         default="Llama3.1-8B-Instruct",
         description="Model descriptor from `llama model list`",
     )
-    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

@@ -38,9 +37,13 @@ class MetaReferenceImplConfig(BaseModel):

     @property
     def model_parallel_size(self) -> int:
         # HACK ALERT: this will be fixed when we move inference configuration
         # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
         # as configuration there
         resolved = resolve_model(self.model)
         assert resolved is not None
         return resolved.pth_file_count
+
+
+class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
+    quantization: QuantizationConfig

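Note: the design choice here is to move quantization off the base config entirely and make it a required field of the subclass, so "quantized or not" is carried by the config's type rather than by an optional field. A short usage sketch of that scheme, with QuantizationConfig reduced to a stub (the real one lives in llama_stack.apis.inference):

    from typing import Optional

    from pydantic import BaseModel


    class QuantizationConfig(BaseModel):  # stub for the real llama_stack type
        type: str = "fp8"


    class MetaReferenceInferenceConfig(BaseModel):
        model: str = "Llama3.1-8B-Instruct"
        torch_seed: Optional[int] = None
        max_seq_len: int = 4096
        max_batch_size: int = 1


    class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
        quantization: QuantizationConfig  # required, no longer Optional


    quant = MetaReferenceQuantizedInferenceConfig(quantization=QuantizationConfig())

    # The subclass still satisfies code typed against the base class ...
    assert isinstance(quant, MetaReferenceInferenceConfig)

    # ... and a quantized config without a quantization block is a validation error.
    try:
        MetaReferenceQuantizedInferenceConfig()
    except Exception:
        print("missing `quantization` rejected, as expected")

That isinstance relationship is what lets get_provider_impl and Llama.build accept either variant through one code path.
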
@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import QuantizationType
-
-from llama_stack.distribution.utils.model_utils import model_local_dir
-
 from pydantic import BaseModel
 from termcolor import cprint

-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


 def model_checkpoint_dir(model) -> str:

@@ -68,7 +65,11 @@ class TokenResult(BaseModel):

 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)

-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
-
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             from .quantization.loader import convert_to_quantized_model

             # load on CPU in bf16 so that fp8 conversion does not find an

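Note: Llama.build switches from data inspection (config.quantization and its type being fp8) to type dispatch (isinstance against the quantized config class). The new check is broader: any MetaReferenceQuantizedInferenceConfig takes the conversion path, and convert_to_quantized_model branches on quantization.type itself (its bf16 case is visible in loader.py below). A reduced sketch of the before/after predicate, with stubs in place of the real configs:

    from typing import Optional

    from pydantic import BaseModel


    class QuantizationConfig(BaseModel):  # stub
        type: str = "fp8"


    class MetaReferenceInferenceConfig(BaseModel):
        # old-style optional field, kept here only to demonstrate the old predicate
        quantization: Optional[QuantizationConfig] = None


    class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
        quantization: QuantizationConfig


    config = MetaReferenceQuantizedInferenceConfig(quantization=QuantizationConfig())

    fp8_old = bool(config.quantization and config.quantization.type == "fp8")  # data inspection
    quantized_new = isinstance(config, MetaReferenceQuantizedInferenceConfig)  # type dispatch

    assert fp8_old and quantized_new
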
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,

@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)


 class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
-    def __init__(self, config: MetaReferenceImplConfig) -> None:
+    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:

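Note: only the config class name changes here, but the hunk context shows the serving model this file assumes: a single model-parallel process guarded by SEMAPHORE = asyncio.Semaphore(1). A minimal sketch of that serialization pattern, with a stub standing in for the real generation call:

    import asyncio

    SEMAPHORE = asyncio.Semaphore(1)


    async def generate(prompt: str) -> str:
        async with SEMAPHORE:  # one request reaches the model process at a time
            await asyncio.sleep(0.01)  # stand-in for actual generation work
            return f"completion for: {prompt}"


    async def main() -> None:
        print(await asyncio.gather(*(generate(f"p{i}") for i in range(3))))


    asyncio.run(main())
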
@@ -6,7 +6,6 @@

 import os
-from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional


@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
 from .parallel_utils import InferenceArgs, ModelParallelProcessGroup

@@ -36,7 +35,7 @@ class ModelRunner:
         )


-def init_model_cb(config: MetaReferenceImplConfig):
+def init_model_cb(config: MetaReferenceInferenceConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)

@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceImplConfig):
+    def __init__(self, config: MetaReferenceInferenceConfig):
         self.config = config
         self.model = resolve_model(self.config.model)
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

@@ -11,7 +11,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Any, Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Union

 import torch

@@ -317,7 +317,7 @@ def start_model_parallel_process(

     request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
-    print(f"Finished model load {response}")
+    print("Loaded model...")

     return request_socket, process

@@ -22,19 +22,10 @@ from torch import Tensor
 from llama_stack.apis.inference import QuantizationType

 from llama_stack.providers.impls.meta_reference.inference.config import (
-    MetaReferenceImplConfig,
+    MetaReferenceQuantizedInferenceConfig,
 )


-def is_fbgemm_available() -> bool:
-    try:
-        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
 def swiglu_wrapper(
     self,
     x: Tensor,

@@ -47,7 +38,7 @@ def swiglu_wrapper(

 def convert_to_quantized_model(
     model: Transformer,
-    config: MetaReferenceImplConfig,
+    config: MetaReferenceQuantizedInferenceConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

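Note: is_fbgemm_available() is deleted together with its only call site in Llama.build, so a missing fbgemm-gpu now surfaces as an ImportError from inside convert_to_quantized_model rather than as an early explicit check. Callers that still want a soft probe can recreate the deleted helper verbatim:

    def is_fbgemm_available() -> bool:
        # Soft-probe for the FP8 kernels, exactly as the deleted helper did.
        try:
            import fbgemm_gpu.experimental.gen_ai  # noqa: F401

            return True
        except ImportError:
            return False
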
@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "blobfile",
+                "fairscale",
+                "torch",
+                "torchvision",
+                "transformers",
+                "zmq",
+            ],
+            module="llama_stack.providers.impls.meta_reference.inference",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
+        ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="meta-reference-quantized",
+            pip_packages=[
+                "accelerate",
+                "blobfile",

@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
                 "zmq",
             ],
             module="llama_stack.providers.impls.meta_reference.inference",
-            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
         remote_provider_spec(
             api=Api.inference,

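Note: the registry now advertises two inline inference providers that share a module but differ in provider_type and config_class; the truncated quantized spec presumably carries the extra FP8 dependencies in its pip_packages. A self-contained sketch of the provider_type-to-config_class mapping these two specs establish (InlineProviderSpec reduced to a dataclass; the strings are copied from the diff):

    from dataclasses import dataclass
    from typing import Dict


    @dataclass
    class InlineProviderSpec:  # reduced stand-in for the real spec type
        provider_type: str
        module: str
        config_class: str


    _MODULE = "llama_stack.providers.impls.meta_reference.inference"

    REGISTRY: Dict[str, InlineProviderSpec] = {
        spec.provider_type: spec
        for spec in [
            InlineProviderSpec(
                provider_type="meta-reference",
                module=_MODULE,
                config_class=f"{_MODULE}.MetaReferenceInferenceConfig",
            ),
            InlineProviderSpec(
                provider_type="meta-reference-quantized",
                module=_MODULE,
                config_class=f"{_MODULE}.MetaReferenceQuantizedInferenceConfig",
            ),
        ]
    }

    print(REGISTRY["meta-reference-quantized"].config_class)
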