mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-12 16:16:09 +00:00
Introduce Llama stack distributions (#22)
* Add distribution CLI scaffolding * More progress towards `llama distribution install` * getting closer to a distro definition, distro install + configure works * Distribution server now functioning * read existing configuration, save enums properly * Remove inference uvicorn server entrypoint and llama inference CLI command * updated dependency and client model name * Improved exception handling * local imports for faster cli * undo a typo, add a passthrough distribution * implement full-passthrough in the server * add safety adapters, configuration handling, server + clients * cleanup, moving stuff to common, nuke utils * Add a Path() wrapper at the earliest place * fixes * Bring agentic system api to toolchain Add adapter dependencies and resolve adapters using a topological sort * refactor to reduce size of `agentic_system` * move straggler files and fix some important existing bugs * ApiSurface -> Api * refactor a method out * Adapter -> Provider * Make each inference provider into its own subdirectory * installation fixes * Rename Distribution -> DistributionSpec, simplify RemoteProviders * dict key instead of attr * update inference config to take model and not model_dir * Fix passthrough streaming, send headers properly not part of body :facepalm * update safety to use model sku ids and not model dirs * Update cli_reference.md * minor fixes * add DistributionConfig, fix a bug in model download * Make install + start scripts do proper configuration automatically * Update CLI_reference * Nuke fp8_requirements, fold fbgemm into common requirements * Update README, add newline between API surface configurations * Refactor download functionality out of the Command so can be reused * Add `llama model download` alias for `llama download` * Show message about checksum file so users can check themselves * Simpler intro statements * get ollama working * Reduce a bunch of dependencies from toolchain Some improvements to the distribution install script * Avoid using `conda run` since it buffers everything * update dependencies and rely on LLAMA_TOOLCHAIN_DIR for dev purposes * add validation for configuration input * resort imports * make optional subclasses default to yes for configuration * Remove additional_pip_packages; move deps to providers * for inline make 8b model the default * Add scripts to MANIFEST * allow installing from test.pypi.org * Fix #2 to help with testing packages * Must install llama-models at that same version first * fix PIP_ARGS --------- Co-authored-by: Hardik Shah <hjshah@fb.com> Co-authored-by: Hardik Shah <hjshah@meta.com>
This commit is contained in:
parent
da4645a27a
commit
e830814399
115 changed files with 5839 additions and 1120 deletions
8
llama_toolchain/inference/meta_reference/__init__.py
Normal file
8
llama_toolchain/inference/meta_reference/__init__.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceImplConfig # noqa
|
||||
from .inference import get_provider_impl # noqa
|
43
llama_toolchain/inference/meta_reference/config.py
Normal file
43
llama_toolchain/inference/meta_reference/config.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.datatypes import ModelFamily
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_toolchain.inference.api import QuantizationConfig
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str = Field(
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int
|
||||
max_batch_size: int = 1
|
||||
|
||||
@validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
permitted_models = [
|
||||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family == ModelFamily.llama3_1
|
||||
]
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
raise ValueError(
|
||||
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
|
||||
)
|
||||
return model
|
335
llama_toolchain/inference/meta_reference/generation.py
Normal file
335
llama_toolchain/inference/meta_reference/generation.py
Normal file
|
@ -0,0 +1,335 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3_1.api.args import ModelArgs
|
||||
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3_1.reference_impl.model import Transformer
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
from llama_toolchain.inference.api import QuantizationType
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model))
|
||||
if not Path(checkpoint_dir / "consolidated.00.pth").exists():
|
||||
checkpoint_dir = checkpoint_dir / "original"
|
||||
|
||||
assert checkpoint_dir.exists(), (
|
||||
f"Could not find checkpoint dir: {checkpoint_dir}."
|
||||
f"Please download model using `llama download {model.descriptor()}`"
|
||||
)
|
||||
return str(checkpoint_dir)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenResult:
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
class Llama:
|
||||
@staticmethod
|
||||
def build(config: MetaReferenceImplConfig):
|
||||
"""
|
||||
Build a Llama instance by initializing and loading a model checkpoint.
|
||||
|
||||
Note:
|
||||
This method initializes the distributed process group, sets the device to CUDA,
|
||||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
model = resolve_model(config.model)
|
||||
|
||||
if (
|
||||
config.quantization
|
||||
and config.quantization.type == QuantizationType.fp8.value
|
||||
):
|
||||
from .quantization.loader import is_fbgemm_available
|
||||
|
||||
if not is_fbgemm_available():
|
||||
raise ImportError("fbgemm-gpu is required for FP8 quantization")
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
model_parallel_size = model.hardware_requirements.gpu_count
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
if config.torch_seed is not None:
|
||||
torch.manual_seed(config.torch_seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(
|
||||
checkpoints
|
||||
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
# TODO(ashwin): this block is so we can load internal checkpoints without additional
|
||||
# fuss. the final code should _not_ have this blurb
|
||||
if "model" in params:
|
||||
params = params["model"]
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
**params,
|
||||
)
|
||||
|
||||
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
fp8 = (
|
||||
config.quantization
|
||||
and config.quantization.type == QuantizationType.fp8.value
|
||||
)
|
||||
|
||||
if fp8:
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
|
||||
if fp8:
|
||||
# load on CPU first since if we are doing fp8, we probably don't
|
||||
# have enough memory on GPU for bf16
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
if not fp8:
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if config.quantization:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
model = convert_to_quantized_model(model, config)
|
||||
else:
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: ModelInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(
|
||||
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
|
||||
)
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
prev_pos = 0
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
if min_prompt_len == total_len:
|
||||
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
||||
logits = self.model.forward(tokens, prev_pos)
|
||||
token_logprobs = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(
|
||||
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
||||
)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
||||
torch.isin(next_token, stop_tokens)
|
||||
)
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
|
||||
if logprobs
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def text_completion(
|
||||
self,
|
||||
prompt: str,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
or max_gen_len >= self.model.params.max_seq_len
|
||||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
prompt_tokens = self.tokenizer.encode(x, bos=True, eos=False)
|
||||
|
||||
yield from self.generate(
|
||||
model_input=ModelInput(tokens=prompt_tokens),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
) -> Generator:
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
or max_gen_len >= self.model.params.max_seq_len
|
||||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(messages),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
include_stop_token=True,
|
||||
)
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
204
llama_toolchain/inference/meta_reference/inference.py
Normal file
204
llama_toolchain/inference/meta_reference/inference.py
Normal file
|
@ -0,0 +1,204 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Dict, Union
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import StopReason
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference):
|
||||
|
||||
def __init__(self, config: MetaReferenceImplConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model = model
|
||||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||
# top-level server is going to do. make the typing more specific here
|
||||
async def chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncIterator[
|
||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
||||
]:
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Unknown model: {request.model}, Run `llama model list`"
|
||||
)
|
||||
elif model.descriptor() != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
async with SEMAPHORE:
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
||||
stop_reason = None
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=request.messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
):
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
logprobs.append(token_result.logprob)
|
||||
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
content=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||
# if someone breaks the iteration before coming here we are toast
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
if request.stream:
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content="",
|
||||
parse_status=ToolCallParseStatus.failure,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||
else:
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
106
llama_toolchain/inference/meta_reference/model_parallel.py
Normal file
106
llama_toolchain/inference/meta_reference/model_parallel.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
from llama_models.llama3_1.api.chat_format import ChatFormat
|
||||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from llama_models.llama3_1.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceArgs:
|
||||
messages: List[Message]
|
||||
temperature: float
|
||||
top_p: float
|
||||
max_gen_len: int
|
||||
logprobs: bool
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(self, llama):
|
||||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, task: InferenceArgs):
|
||||
return self.llama.chat_completion(
|
||||
task.messages,
|
||||
task.temperature,
|
||||
task.top_p,
|
||||
task.max_gen_len,
|
||||
task.logprobs,
|
||||
)
|
||||
|
||||
|
||||
def init_model_cb(config: MetaReferenceImplConfig):
|
||||
llama = Llama.build(config)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
class LlamaModelParallelGenerator:
|
||||
"""
|
||||
This abstraction exists so
|
||||
- we can run model parallel code without needing to run the CLIs via torchrun
|
||||
- this also enables use model parallel code within a notebook context.
|
||||
|
||||
A Context Manager is used to ensure that the model parallel process is started and stopped
|
||||
correctly. This does make the ergonomics a little awkward, because it isn't immediately
|
||||
clear at the callsite why we need to use a context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: MetaReferenceImplConfig):
|
||||
self.config = config
|
||||
self.model = resolve_model(self.config.model)
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
checkpoint_dir = model_checkpoint_dir(self.model)
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
||||
def stop(self):
|
||||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
self.group = ModelParallelProcessGroup(
|
||||
self.model.hardware_requirements.gpu_count,
|
||||
init_model_cb=partial(init_model_cb, self.config),
|
||||
)
|
||||
self.group.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.group.stop()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
) -> Generator:
|
||||
req_obj = InferenceArgs(
|
||||
messages=deepcopy(messages),
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
gen = self.group.run_inference(req_obj)
|
||||
yield from gen
|
265
llama_toolchain/inference/meta_reference/parallel_utils.py
Normal file
265
llama_toolchain/inference/meta_reference/parallel_utils.py
Normal file
|
@ -0,0 +1,265 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from typing import Callable, Generator
|
||||
|
||||
import torch
|
||||
|
||||
import zmq
|
||||
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_group,
|
||||
get_model_parallel_rank,
|
||||
get_model_parallel_src_rank,
|
||||
)
|
||||
|
||||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
|
||||
|
||||
_END_SENTINEL = "__end_sentinel__"
|
||||
_CANCEL_SENTINEL = "__cancel_sentinel__"
|
||||
|
||||
|
||||
def mp_rank_0() -> bool:
|
||||
return get_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def retrieve_requests(reply_socket_url: str):
|
||||
if mp_rank_0():
|
||||
context = zmq.Context()
|
||||
reply_socket = context.socket(zmq.ROUTER)
|
||||
reply_socket.connect(reply_socket_url)
|
||||
|
||||
while True:
|
||||
client_id, obj = maybe_get_work(reply_socket)
|
||||
if obj is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
reply_socket.send_multipart([client_id, pickle.dumps("YES READY")])
|
||||
break
|
||||
|
||||
def send_obj(obj):
|
||||
reply_socket.send_multipart([client_id, pickle.dumps(obj)])
|
||||
|
||||
while True:
|
||||
tasks = [None]
|
||||
if mp_rank_0():
|
||||
client_id, task = maybe_get_work(reply_socket)
|
||||
# there is still an unknown unclean GeneratorExit happening resulting in a
|
||||
# cancel sentinel getting queued _after_ we have finished sending everything :/
|
||||
# kind of a hack this is :/
|
||||
if task != _CANCEL_SENTINEL:
|
||||
tasks = [task]
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
tasks,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
|
||||
task = tasks[0]
|
||||
if task is None:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
try:
|
||||
out = yield task
|
||||
if out is None:
|
||||
break
|
||||
|
||||
for obj in out:
|
||||
updates = [None]
|
||||
if mp_rank_0():
|
||||
_, update = maybe_get_work(reply_socket)
|
||||
if update == _CANCEL_SENTINEL:
|
||||
updates = [update]
|
||||
else:
|
||||
# only send the update if it's not cancelled otherwise the object sits in the socket
|
||||
# and gets pulled in the next request lol
|
||||
send_obj(obj)
|
||||
|
||||
torch.distributed.broadcast_object_list(
|
||||
updates,
|
||||
src=get_model_parallel_src_rank(),
|
||||
group=get_model_parallel_group(),
|
||||
)
|
||||
if updates[0] == _CANCEL_SENTINEL:
|
||||
print("quitting generation loop because request was cancelled")
|
||||
break
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj(_END_SENTINEL)
|
||||
except Exception as e:
|
||||
print(f"[debug] got exception {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
if mp_rank_0():
|
||||
send_obj(e)
|
||||
|
||||
if mp_rank_0():
|
||||
send_obj("DONE")
|
||||
|
||||
|
||||
def maybe_get_work(sock: zmq.Socket):
|
||||
message = None
|
||||
client_id = None
|
||||
try:
|
||||
client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
|
||||
message = pickle.loads(obj)
|
||||
except zmq.ZMQError as e:
|
||||
if e.errno != zmq.EAGAIN:
|
||||
raise e
|
||||
|
||||
return client_id, message
|
||||
|
||||
|
||||
def worker_process_entrypoint(
|
||||
reply_socket_url: str,
|
||||
init_model_cb: Callable,
|
||||
) -> None:
|
||||
model = init_model_cb()
|
||||
torch.distributed.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# run the requests co-routine which retrieves requests from the socket
|
||||
# and sends responses (we provide) back to the caller
|
||||
req_gen = retrieve_requests(reply_socket_url)
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
task = req_gen.send(result)
|
||||
if isinstance(task, str) and task == _END_SENTINEL:
|
||||
break
|
||||
|
||||
result = model(task)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
print("[debug] worker process done")
|
||||
|
||||
|
||||
def launch_dist_group(
|
||||
reply_socket_url: str,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
id = uuid.uuid4().hex
|
||||
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
|
||||
launch_config = LaunchConfig(
|
||||
max_nodes=1,
|
||||
min_nodes=1,
|
||||
nproc_per_node=model_parallel_size,
|
||||
start_method="fork",
|
||||
rdzv_backend="c10d",
|
||||
rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
|
||||
rdzv_configs={"store_type": "file", "timeout": 90},
|
||||
max_restarts=0,
|
||||
monitor_interval=1,
|
||||
run_id=str(uuid.uuid4()),
|
||||
)
|
||||
elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
|
||||
reply_socket_url,
|
||||
init_model_cb,
|
||||
)
|
||||
|
||||
|
||||
def start_model_parallel_process(
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
context = zmq.Context()
|
||||
request_socket = context.socket(zmq.DEALER)
|
||||
|
||||
# Binding the request socket to a random port
|
||||
request_socket.bind("tcp://127.0.0.1:0")
|
||||
|
||||
main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
|
||||
|
||||
ctx = multiprocessing.get_context("fork")
|
||||
process = ctx.Process(
|
||||
target=launch_dist_group,
|
||||
args=(
|
||||
main_process_url,
|
||||
model_parallel_size,
|
||||
init_model_cb,
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
process.start()
|
||||
|
||||
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
|
||||
|
||||
request_socket.send_pyobj("READY?")
|
||||
response = request_socket.recv_pyobj()
|
||||
print(f"Finished model load {response}")
|
||||
|
||||
return request_socket, process
|
||||
|
||||
|
||||
class ModelParallelProcessGroup:
|
||||
def __init__(
|
||||
self,
|
||||
model_parallel_size: int,
|
||||
init_model_cb: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.init_model_cb = init_model_cb
|
||||
self.started = False
|
||||
self.running = False
|
||||
|
||||
def start(self):
|
||||
assert not self.started, "process group already started"
|
||||
self.request_socket, self.process = start_model_parallel_process(
|
||||
self.model_parallel_size,
|
||||
self.init_model_cb,
|
||||
)
|
||||
self.started = True
|
||||
|
||||
def stop(self):
|
||||
assert self.started, "process group not started"
|
||||
if self.process.is_alive():
|
||||
self.request_socket.send_pyobj(_END_SENTINEL, zmq.NOBLOCK)
|
||||
self.process.join()
|
||||
self.started = False
|
||||
|
||||
def run_inference(self, request) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
self.running = True
|
||||
self.request_socket.send_pyobj(request)
|
||||
try:
|
||||
while True:
|
||||
obj = self.request_socket.recv_pyobj()
|
||||
if obj == _END_SENTINEL:
|
||||
break
|
||||
|
||||
if isinstance(obj, Exception):
|
||||
print(f"[debug] got exception {obj}")
|
||||
raise obj
|
||||
|
||||
yield obj
|
||||
except GeneratorExit as e:
|
||||
self.request_socket.send_pyobj(_CANCEL_SENTINEL)
|
||||
while True:
|
||||
obj = self.request_socket.recv_pyobj()
|
||||
if obj == _END_SENTINEL:
|
||||
break
|
||||
finally:
|
||||
self.running = False
|
Loading…
Add table
Add a link
Reference in a new issue