From 1ff0476002de90a1de54af9f1e4f4b9c75fc91b8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 10 Oct 2024 15:54:08 -0700 Subject: [PATCH] Split off meta-reference-quantized provider --- llama_stack/cli/stack/build.py | 6 +++- .../agents/tests/test_chat_agent.py | 4 +-- .../meta_reference/inference/__init__.py | 13 ++++---- .../impls/meta_reference/inference/config.py | 11 +++---- .../meta_reference/inference/generation.py | 33 ++++++------------- .../meta_reference/inference/inference.py | 4 +-- .../inference/model_parallel.py | 7 ++-- .../inference/parallel_utils.py | 4 +-- .../inference/quantization/loader.py | 13 ++------ llama_stack/providers/registry/inference.py | 17 +++++++++- 10 files changed, 54 insertions(+), 58 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 3fe615e6e..3c59e8c20 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -149,6 +149,7 @@ class StackBuild(Subcommand): def _run_template_list_cmd(self, args: argparse.Namespace) -> None: import json + from llama_stack.cli.table import print_table # eventually, this should query a registry at llama.meta.com/llamastack/distributions @@ -175,6 +176,7 @@ class StackBuild(Subcommand): def _run_stack_build_command(self, args: argparse.Namespace) -> None: import textwrap + import yaml from llama_stack.distribution.distribution import get_provider_registry from prompt_toolkit import prompt @@ -256,7 +258,9 @@ class StackBuild(Subcommand): providers = dict() for api, providers_for_api in get_provider_registry().items(): available_providers = [ - x for x in providers_for_api.keys() if x != "remote" + x + for x in providers_for_api.keys() + if x not in ("remote", "remote::sample") ] api_provider = prompt( "> Enter provider for API {}: ".format(api.value), diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py index 9d941edc9..46423814b 100644 --- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py +++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py @@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403 from ..agents import ( AGENT_INSTANCES_BY_ID, MetaReferenceAgentsImpl, - MetaReferenceImplConfig, + MetaReferenceInferenceConfig, ) @@ -166,7 +166,7 @@ def mock_memory_api(): @pytest.fixture async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api): impl = MetaReferenceAgentsImpl( - config=MetaReferenceImplConfig(), + config=MetaReferenceInferenceConfig(), inference_api=mock_inference_api, safety_api=mock_safety_api, memory_api=mock_memory_api, diff --git a/llama_stack/providers/impls/meta_reference/inference/__init__.py b/llama_stack/providers/impls/meta_reference/inference/__init__.py index 64d315e79..9c923490d 100644 --- a/llama_stack/providers/impls/meta_reference/inference/__init__.py +++ b/llama_stack/providers/impls/meta_reference/inference/__init__.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import MetaReferenceImplConfig # noqa +from typing import Union + +from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig -async def get_provider_impl(config: MetaReferenceImplConfig, _deps): +async def get_provider_impl( + config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], + _deps, +): from .inference import MetaReferenceInferenceImpl - assert isinstance( - config, MetaReferenceImplConfig - ), f"Unexpected config type: {type(config)}" - impl = MetaReferenceInferenceImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py index ba5eddd53..901a8c7fb 100644 --- a/llama_stack/providers/impls/meta_reference/inference/config.py +++ b/llama_stack/providers/impls/meta_reference/inference/config.py @@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator from llama_stack.providers.utils.inference import supported_inference_models -class MetaReferenceImplConfig(BaseModel): +class MetaReferenceInferenceConfig(BaseModel): model: str = Field( default="Llama3.1-8B-Instruct", description="Model descriptor from `llama model list`", ) - quantization: Optional[QuantizationConfig] = None torch_seed: Optional[int] = None max_seq_len: int = 4096 max_batch_size: int = 1 @@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel): @property def model_parallel_size(self) -> int: - # HACK ALERT: this will be fixed when we move inference configuration - # to ModelsRegistry and we can explicitly ask for `model_parallel_size` - # as configuration there resolved = resolve_model(self.model) - assert resolved is not None return resolved.pth_file_count + + +class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig): + quantization: QuantizationConfig diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index 37aef5ede..8d94a20d1 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -11,9 +11,8 @@ import json import os import sys import time -from dataclasses import dataclass from pathlib import Path -from typing import Generator, List, Optional +from typing import Generator, List, Optional, Union import torch import torch.nn.functional as F @@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import ( ) from llama_models.sku_list import resolve_model -from llama_stack.apis.inference import QuantizationType - -from llama_stack.distribution.utils.model_utils import model_local_dir - from pydantic import BaseModel from termcolor import cprint -from .config import MetaReferenceImplConfig +from llama_stack.distribution.utils.model_utils import model_local_dir + +from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig def model_checkpoint_dir(model) -> str: @@ -68,7 +65,11 @@ class TokenResult(BaseModel): class Llama: @staticmethod - def build(config: MetaReferenceImplConfig): + def build( + config: Union[ + MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig + ] + ): """ Build a Llama instance by initializing and loading a model checkpoint. @@ -78,15 +79,6 @@ class Llama: """ model = resolve_model(config.model) - if ( - config.quantization - and config.quantization.type == QuantizationType.fp8.value - ): - from .quantization.loader import is_fbgemm_available - - if not is_fbgemm_available(): - raise ImportError("fbgemm-gpu is required for FP8 quantization") - if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") @@ -134,12 +126,7 @@ class Llama: model_args.vocab_size == tokenizer.n_words ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" - fp8 = ( - config.quantization - and config.quantization.type == QuantizationType.fp8.value - ) - - if fp8: + if isinstance(config, MetaReferenceQuantizedInferenceConfig): from .quantization.loader import convert_to_quantized_model # load on CPU in bf16 so that fp8 conversion does not find an diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index a8afcea54..6696762c9 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) -from .config import MetaReferenceImplConfig +from .config import MetaReferenceInferenceConfig from .model_parallel import LlamaModelParallelGenerator # there's a single model parallel process running serving the model. for now, @@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1) class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): - def __init__(self, config: MetaReferenceImplConfig) -> None: + def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) if model is None: diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py index 798fadcbe..e8f483f30 100644 --- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py +++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py @@ -6,7 +6,6 @@ import os from copy import deepcopy -from dataclasses import dataclass from functools import partial from typing import Generator, List, Optional @@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model -from .config import MetaReferenceImplConfig +from .config import MetaReferenceInferenceConfig from .generation import Llama, model_checkpoint_dir from .parallel_utils import InferenceArgs, ModelParallelProcessGroup @@ -36,7 +35,7 @@ class ModelRunner: ) -def init_model_cb(config: MetaReferenceImplConfig): +def init_model_cb(config: MetaReferenceInferenceConfig): llama = Llama.build(config) return ModelRunner(llama) @@ -52,7 +51,7 @@ class LlamaModelParallelGenerator: clear at the callsite why we need to use a context manager. """ - def __init__(self, config: MetaReferenceImplConfig): + def __init__(self, config: MetaReferenceInferenceConfig): self.config = config self.model = resolve_model(self.config.model) # this is a hack because Agent's loop uses this to tokenize and check if input is too long diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py index c6eacc73c..7dbedd0f0 100644 --- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py +++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py @@ -11,7 +11,7 @@ import tempfile import time import uuid from enum import Enum -from typing import Any, Callable, Generator, List, Literal, Optional, Union +from typing import Callable, Generator, List, Literal, Optional, Union import torch @@ -317,7 +317,7 @@ def start_model_parallel_process( request_socket.send(encode_msg(ReadyRequest())) response = request_socket.recv() - print(f"Finished model load {response}") + print("Loaded model...") return request_socket, process diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py index 1df86cb84..92b3a6ce3 100644 --- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py @@ -22,19 +22,10 @@ from torch import Tensor from llama_stack.apis.inference import QuantizationType from llama_stack.providers.impls.meta_reference.inference.config import ( - MetaReferenceImplConfig, + MetaReferenceQuantizedInferenceConfig, ) -def is_fbgemm_available() -> bool: - try: - import fbgemm_gpu.experimental.gen_ai # noqa: F401 - - return True - except ImportError: - return False - - def swiglu_wrapper( self, x: Tensor, @@ -47,7 +38,7 @@ def swiglu_wrapper( def convert_to_quantized_model( model: Transformer, - config: MetaReferenceImplConfig, + config: MetaReferenceQuantizedInferenceConfig, fp8_activation_scale_ub: Optional[float] = 1200.0, ) -> Transformer: if config.quantization.type == QuantizationType.bf16.value: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index ddfd4ff40..686fc273b 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.inference, provider_type="meta-reference", + pip_packages=[ + "accelerate", + "blobfile", + "fairscale", + "torch", + "torchvision", + "transformers", + "zmq", + ], + module="llama_stack.providers.impls.meta_reference.inference", + config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig", + ), + InlineProviderSpec( + api=Api.inference, + provider_type="meta-reference-quantized", pip_packages=[ "accelerate", "blobfile", @@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]: "zmq", ], module="llama_stack.providers.impls.meta_reference.inference", - config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig", + config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", ), remote_provider_spec( api=Api.inference,