Merge branch 'main' into evals_new

Xi Yan authored 2024-10-15 10:20:03 -07:00, committed by GitHub
commit 2c23a66300
24 changed files with 112 additions and 120 deletions

View file

@@ -7,10 +7,11 @@
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl
return impl
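
For orientation, a hedged sketch of how this factory is normally driven. The package path and the `url` field are assumptions; only `DatabricksImplConfig`, its `api_token` field (next file), and `get_adapter_impl` appear in this diff.

```python
# Hedged sketch, not part of this commit. Assumes the adapter package is
# llama_stack.providers.adapters.inference.databricks and that the config also
# accepts a `url` field (only `api_token` is visible in this diff).
import asyncio

from llama_stack.providers.adapters.inference.databricks import (
    DatabricksImplConfig,
    get_adapter_impl,
)


async def build_adapter():
    config = DatabricksImplConfig(
        url="https://my-workspace.cloud.databricks.com",  # assumed field
        api_token="dapi-...",                             # placeholder token
    )
    # get_adapter_impl asserts the config type, constructs the adapter, and
    # awaits initialize(), so the caller receives a ready DatabricksInferenceAdapter.
    return await get_adapter_impl(config, {})


if __name__ == "__main__":
    # Requires a reachable Databricks serving endpoint.
    adapter = asyncio.run(build_adapter())
    print(type(adapter).__name__)
```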

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
api_token: str = Field(
default=None,
description="The Databricks API token",
)
)

View file

@@ -48,7 +48,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(

View file

@@ -63,13 +63,12 @@ async def llm_rag_query_generator(
model = config.model
message = UserMessage(content=content)
response = inference_api.chat_completion(
response = await inference_api.chat_completion(
model=model,
messages=[message],
stream=False,
)
async for chunk in response:
query = chunk.completion_message.content
query = response.completion_message.content
return query
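
The change above switches from iterating a streaming generator to awaiting a single non-streaming response. A self-contained sketch of the calling pattern follows; the `_FakeInferenceApi` stub and model name are illustrative, and only the `chat_completion(model=..., messages=[...], stream=False)` shape and `.completion_message.content` come from this diff.

```python
# Minimal stand-in for the inference API to illustrate the non-streaming call shape.
import asyncio
from dataclasses import dataclass


@dataclass
class _CompletionMessage:
    content: str


@dataclass
class _ChatCompletionResponse:
    completion_message: _CompletionMessage


class _FakeInferenceApi:
    async def chat_completion(self, model, messages, stream=False):
        # With stream=False the API returns one response object, not an async
        # generator, which is why the caller now awaits it directly.
        return _ChatCompletionResponse(_CompletionMessage("rewritten query"))


async def main() -> None:
    inference_api = _FakeInferenceApi()
    response = await inference_api.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[{"role": "user", "content": "original question"}],
        stream=False,
    )
    query = response.completion_message.content  # no `async for` loop needed anymore
    print(query)


asyncio.run(main())
```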

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceImplConfig,
MetaReferenceInferenceConfig,
)
@@ -166,7 +166,7 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
impl = MetaReferenceAgentsImpl(
config=MetaReferenceImplConfig(),
config=MetaReferenceInferenceConfig(),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,

View file

@@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from typing import Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
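
A hedged sketch of calling the widened entry point: either config class can now be handed to the same function, since the annotation is a Union and the isinstance assert is gone. The field value is illustrative, and `initialize()` would need the model checkpoints to be present locally.

```python
# Hedged sketch, not part of this commit.
import asyncio

from llama_stack.providers.impls.meta_reference.inference import (
    MetaReferenceInferenceConfig,
    get_provider_impl,
)


async def build_impl():
    # The same entry point also accepts MetaReferenceQuantizedInferenceConfig.
    config = MetaReferenceInferenceConfig(model="Llama3.1-8B-Instruct")
    return await get_provider_impl(config, {})


if __name__ == "__main__":
    # Requires downloaded checkpoints for the configured model.
    impl = asyncio.run(build_impl())
    print(type(impl).__name__)  # MetaReferenceInferenceImpl
```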

View file

@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):
@property
def model_parallel_size(self) -> int:
# HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
resolved = resolve_model(self.model)
assert resolved is not None
return resolved.pth_file_count
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
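
A hedged sketch of the resulting contract: it appears the optional `quantization` field moves off the base class, while the new subclass requires it, so a missing quantization setting now fails at validation time rather than deep inside the model loader.

```python
# Hedged sketch, not part of this commit.
from pydantic import ValidationError

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
    MetaReferenceQuantizedInferenceConfig,
)

# Base config: defaults pick up Llama3.1-8B-Instruct with no quantization at all.
base = MetaReferenceInferenceConfig()
print(base.model, base.max_seq_len)

# Quantized config: `quantization` is a required field, so leaving it out is
# rejected immediately by pydantic.
try:
    MetaReferenceQuantizedInferenceConfig()
except ValidationError as exc:
    print(f"{len(exc.errors())} validation error(s): quantization is required")
```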

View file

@@ -11,9 +11,8 @@ import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from pydantic import BaseModel
from termcolor import cprint
from .config import MetaReferenceImplConfig
from llama_stack.distribution.utils.model_utils import model_local_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
]
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@@ -78,15 +79,6 @@ class Llama:
"""
model = resolve_model(config.model)
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@@ -134,12 +126,7 @@ class Llama:
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
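
The net effect of the two hunks above: the fbgemm pre-flight check and the `config.quantization`-based fp8 flag are gone, and quantized loading is keyed off the config's type instead. A minimal self-contained mirror of that dispatch (stand-in classes, not llama_stack imports):

```python
# Self-contained mirror of the type-based dispatch introduced above.
class InferenceConfig:
    """Stand-in for MetaReferenceInferenceConfig."""


class QuantizedInferenceConfig(InferenceConfig):
    """Stand-in for MetaReferenceQuantizedInferenceConfig."""

    def __init__(self, quantization_type: str) -> None:
        self.quantization_type = quantization_type


def load_model(config: InferenceConfig) -> str:
    # Old code checked `config.quantization and config.quantization.type == fp8`;
    # new code only asks which config class it was given.
    if isinstance(config, QuantizedInferenceConfig):
        return f"convert_to_quantized_model({config.quantization_type})"
    return "plain load"


print(load_model(InferenceConfig()))                # plain load
print(load_model(QuantizedInferenceConfig("fp8")))  # convert_to_quantized_model(fp8)
```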

View file

@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceImplConfig) -> None:
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:

View file

@@ -6,7 +6,6 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
@@ -36,7 +35,7 @@ class ModelRunner:
)
def init_model_cb(config: MetaReferenceImplConfig):
def init_model_cb(config: MetaReferenceInferenceConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceImplConfig):
def __init__(self, config: MetaReferenceInferenceConfig):
self.config = config
self.model = resolve_model(self.config.model)
# this is a hack because Agent's loop uses this to tokenize and check if input is too long

View file

@@ -11,7 +11,7 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Any, Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, List, Literal, Optional, Union
import torch
@@ -317,7 +317,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print(f"Finished model load {response}")
print("Loaded model...")
return request_socket, process

View file

@@ -22,19 +22,10 @@ from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceImplConfig,
MetaReferenceQuantizedInferenceConfig,
)
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
@@ -47,7 +38,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceImplConfig,
config: MetaReferenceQuantizedInferenceConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
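
Note that `is_fbgemm_available()` is deleted above along with its call site in generation.py, so a missing fbgemm-gpu now surfaces wherever the FP8 kernels are first imported. A hedged sketch of an equivalent caller-side probe, reusing the import and error message from the removed helper:

```python
# Hedged sketch, not part of this commit: reproduces the deleted pre-flight check
# for anyone who still wants a friendly error before attempting FP8 conversion.
def fbgemm_gen_ai_available() -> bool:
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
        return True
    except ImportError:
        return False


if not fbgemm_gen_ai_available():
    raise ImportError("fbgemm-gpu is required for FP8 quantization")
```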

View file

@@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages

View file

@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig

View file

@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference-quantized",
pip_packages=[
"accelerate",
"blobfile",
@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
),
remote_provider_spec(
api=Api.inference,
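
Finally, a hedged sketch of how the new registry entry can be looked up. The registry module path `llama_stack.providers.registry.inference` is an assumption; the `provider_type` and `config_class` values are taken from the hunk above.

```python
# Hedged sketch, not part of this commit.
from llama_stack.providers.registry.inference import available_providers

quantized_spec = next(
    spec
    for spec in available_providers()
    if spec.provider_type == "meta-reference-quantized"
)
# Expected, per this diff:
# llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig
print(quantized_spec.config_class)
```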