Split off meta-reference-quantized provider

Ashwin Bharambe 2024-10-10 15:54:08 -07:00
parent 7ff5800dea
commit 1ff0476002
10 changed files with 54 additions and 58 deletions

View file

@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json

         from llama_stack.cli.table import print_table

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         import textwrap

         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt
@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
                 available_providers = [
-                    x for x in providers_for_api.keys() if x != "remote"
+                    x
+                    for x in providers_for_api.keys()
+                    if x not in ("remote", "remote::sample")
                 ]
                 api_provider = prompt(
                     "> Enter provider for API {}: ".format(api.value),

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.agents import *  # noqa: F403
 from ..agents import (
     AGENT_INSTANCES_BY_ID,
     MetaReferenceAgentsImpl,
-    MetaReferenceImplConfig,
+    MetaReferenceInferenceConfig,
 )
@@ -166,7 +166,7 @@ def mock_memory_api():
 @pytest.fixture
 async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
     impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceImplConfig(),
+        config=MetaReferenceInferenceConfig(),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
         memory_api=mock_memory_api,

View file

@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import MetaReferenceImplConfig  # noqa
+from typing import Union
+
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


-async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+async def get_provider_impl(
+    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    _deps,
+):
     from .inference import MetaReferenceInferenceImpl

-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
     impl = MetaReferenceInferenceImpl(config)
     await impl.initialize()
     return impl
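
Note: with the isinstance assertion gone, get_provider_impl simply forwards whichever config variant it receives. A hedged sketch of calling it (the field value and the empty deps dict are illustrative assumptions; actually running this requires downloaded checkpoints and a GPU environment):

# Sketch only; assumes llama_stack is installed.
import asyncio

from llama_stack.providers.impls.meta_reference.inference import get_provider_impl
from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)


async def main() -> None:
    config = MetaReferenceInferenceConfig(model="Llama3.1-8B-Instruct")
    # The same call accepts a MetaReferenceQuantizedInferenceConfig as well.
    impl = await get_provider_impl(config, {})
    print(type(impl).__name__)  # MetaReferenceInferenceImpl


# asyncio.run(main())  # commented out: initialize() loads model weights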

View file

@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
 from llama_stack.providers.utils.inference import supported_inference_models


-class MetaReferenceImplConfig(BaseModel):
+class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
         default="Llama3.1-8B-Instruct",
         description="Model descriptor from `llama model list`",
     )
-    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):
     @property
     def model_parallel_size(self) -> int:
-        # HACK ALERT: this will be fixed when we move inference configuration
-        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
-        # as configuration there
         resolved = resolve_model(self.model)
+        assert resolved is not None
         return resolved.pth_file_count


+class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
+    quantization: QuantizationConfig
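
Note: quantization is no longer an optional field on the base config; it lives only on the new subclass, where it is required. A minimal sketch of constructing each variant (field values are illustrative, and the QuantizationConfig constructor is not shown in this diff, so it is left elided):

# Sketch only; values are illustrative assumptions.
from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
    MetaReferenceQuantizedInferenceConfig,
)

# Plain provider config: no quantization field exists here any more.
plain = MetaReferenceInferenceConfig(
    model="Llama3.1-8B-Instruct",
    max_seq_len=4096,
    max_batch_size=1,
)

# Quantized provider config: `quantization` is now mandatory, so omitting it
# raises a pydantic validation error.
# quantized = MetaReferenceQuantizedInferenceConfig(
#     model="Llama3.1-8B-Instruct",
#     quantization=...,  # an fp8 QuantizationConfig; construction elided here
# )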

View file

@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.distribution.utils.model_utils import model_local_dir
-
 from pydantic import BaseModel
 from termcolor import cprint

-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


 def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)

-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
-
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             from .quantization.loader import convert_to_quantized_model

             # load on CPU in bf16 so that fp8 conversion does not find an
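
Note: the net effect in Llama.build is that the FP8 path is selected by the config's type rather than by inspecting a quantization field, and the old fbgemm availability probe is dropped. A small sketch of that dispatch pattern, with checkpoint loading and distributed setup omitted:

# Sketch only; not the full implementation.
from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
    MetaReferenceQuantizedInferenceConfig,
)


def wants_quantized_model(config: MetaReferenceInferenceConfig) -> bool:
    # convert_to_quantized_model() is only reached when the quantized
    # config subclass is passed in; a plain config skips it entirely.
    return isinstance(config, MetaReferenceQuantizedInferenceConfig)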

View file

@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,
@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)
 class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
-    def __init__(self, config: MetaReferenceImplConfig) -> None:
+    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:

View file

@@ -6,7 +6,6 @@
 import os
 from copy import deepcopy
-from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
 from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
@@ -36,7 +35,7 @@ class ModelRunner:
         )


-def init_model_cb(config: MetaReferenceImplConfig):
+def init_model_cb(config: MetaReferenceInferenceConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)
@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceImplConfig):
+    def __init__(self, config: MetaReferenceInferenceConfig):
         self.config = config
         self.model = resolve_model(self.config.model)
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

View file

@@ -11,7 +11,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Any, Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Union

 import torch
@@ -317,7 +317,7 @@ def start_model_parallel_process(
     request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
-    print(f"Finished model load {response}")
+    print("Loaded model...")

     return request_socket, process

View file

@@ -22,19 +22,10 @@ from torch import Tensor
 from llama_stack.apis.inference import QuantizationType

 from llama_stack.providers.impls.meta_reference.inference.config import (
-    MetaReferenceImplConfig,
+    MetaReferenceQuantizedInferenceConfig,
 )


-def is_fbgemm_available() -> bool:
-    try:
-        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
 def swiglu_wrapper(
     self,
     x: Tensor,
@@ -47,7 +38,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
-    config: MetaReferenceImplConfig,
+    config: MetaReferenceQuantizedInferenceConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

View file

@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
+            pip_packages=[
+                "accelerate",
+                "blobfile",
+                "fairscale",
+                "torch",
+                "torchvision",
+                "transformers",
+                "zmq",
+            ],
+            module="llama_stack.providers.impls.meta_reference.inference",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
+        ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="meta-reference-quantized",
             pip_packages=[
                 "accelerate",
                 "blobfile",
@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
                 "zmq",
             ],
             module="llama_stack.providers.impls.meta_reference.inference",
-            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
         remote_provider_spec(
             api=Api.inference,
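
Note: taken together, the registry now advertises two inline inference providers that share a module but differ in provider_type and config_class. A hedged sketch of inspecting that (the registry module path below is an assumption based on the provider registry layout; it is not named in this diff):

# Sketch only: the import path is assumed, not shown in the diff.
from llama_stack.providers.registry.inference import available_providers

for spec in available_providers():
    if spec.provider_type in ("meta-reference", "meta-reference-quantized"):
        print(spec.provider_type, "->", spec.config_class)
# meta-reference -> ...MetaReferenceInferenceConfig
# meta-reference-quantized -> ...MetaReferenceQuantizedInferenceConfig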