Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 02:03:44 +00:00)
chore(package): migrate to src/ layout (#3920)
Migrates the package structure to a src/ layout, following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. The public API is unchanged: imports remain `import llama_stack.*`. Build configs, pre-commit hooks, scripts, and GitHub workflows were updated accordingly. All hooks pass and the package builds cleanly. **Developer note**: reinstall after pulling: `pip install -e .`
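For example, a quick sanity check after reinstalling (the submodule below is just one of the packages moved by this commit; treat it as an illustrative smoke test, not part of the commit itself):

```python
# Assumes `pip install -e .` has been re-run after pulling this commit.
import llama_stack.providers.inline.inference as inline_inference

# In an editable install this should now resolve under src/llama_stack/.
print(inline_inference.__file__)
```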
parent 98a5047f9d
commit 471b1b248b
791 changed files with 2983 additions and 456 deletions
5
src/llama_stack/providers/inline/inference/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from .config import MetaReferenceInferenceConfig


async def get_provider_impl(
    config: MetaReferenceInferenceConfig,
    _deps: dict[str, Any],
):
    from .inference import MetaReferenceInferenceImpl

    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    return impl
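For context, a minimal sketch of constructing this provider directly; the absolute module path and the config values are assumptions based on the relative imports above (the file name is not shown in this hunk), not something this diff states:

```python
import asyncio

# Assumed module path for the file above.
from llama_stack.providers.inline.inference.meta_reference import get_provider_impl
from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig


async def main() -> None:
    # Illustrative values; a real deployment passes the resolved run.yaml config and deps.
    config = MetaReferenceInferenceConfig(model="Llama3.2-3B-Instruct")
    impl = await get_provider_impl(config, _deps={})
    print(type(impl).__name__)


asyncio.run(main())
```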
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.core.utils.model_utils import model_local_dir


def model_checkpoint_dir(model_id) -> str:
    checkpoint_dir = Path(model_local_dir(model_id))

    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in paths):
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you are trying to use a native Llama model, please download it using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
    )
    return str(checkpoint_dir)
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, field_validator

from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models


class MetaReferenceInferenceConfig(BaseModel):
    # this is a placeholder to indicate the inference model id
    # the actual inference model id is determined by the model id in the request
    # Note: you need to register the model before using it for inference
    # models in the resource list in the run.yaml config will be registered automatically
    model: str | None = None
    torch_seed: int | None = None
    max_seq_len: int = 4096
    max_batch_size: int = 1
    model_parallel_size: int | None = None

    # when this is False, we assume that the distributed process group is set up by someone
    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
    # (including our testing code) who might be using llama-stack as a library.
    create_distributed_process_group: bool = True

    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
    # can override by specifying the directory explicitly
    checkpoint_dir: str | None = None

    quantization: QuantizationConfig | None = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = supported_inference_models()
        descriptors = [m.descriptor() for m in permitted_models]
        repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
        if model not in (descriptors + repos):
            model_list = "\n\t".join(repos)
            raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
        return model

    @classmethod
    def sample_run_config(
        cls,
        model: str = "Llama3.2-3B-Instruct",
        checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
        quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
        max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
        max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
        **kwargs,
    ) -> dict[str, Any]:
        return {
            "model": model,
            "checkpoint_dir": checkpoint_dir,
            "quantization": {
                "type": quantization_type,
            },
            "model_parallel_size": model_parallel_size,
            "max_batch_size": max_batch_size,
            "max_seq_len": max_seq_len,
        }
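For reference, the defaults above mean `sample_run_config()` produces env-substitutable values like the following sketch of the return value (taken directly from the method body; the import path is assumed from the provider layout):

```python
from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig  # assumed path

cfg = MetaReferenceInferenceConfig.sample_run_config()
# cfg == {
#     "model": "Llama3.2-3B-Instruct",
#     "checkpoint_dir": "${env.CHECKPOINT_DIR:=null}",
#     "quantization": {"type": "${env.QUANTIZATION_TYPE:=bf16}"},
#     "model_parallel_size": "${env.MODEL_PARALLEL_SIZE:=0}",
#     "max_batch_size": "${env.MAX_BATCH_SIZE:=1}",
#     "max_seq_len": "${env.MAX_SEQ_LEN:=4096}",
# }
```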
@@ -0,0 +1,211 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import math
from collections.abc import Generator
from typing import Optional

import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

from llama_stack.apis.inference import (
    GreedySamplingStrategy,
    JsonSchemaResponseFormat,
    ResponseFormat,
    SamplingParams,
    TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
    get_default_tool_prompt_format,
)

from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig
from .inference import resolve_model

Tokenizer = Llama4Tokenizer | Llama3Tokenizer


class LogitsProcessor:
    def __init__(self, token_enforcer: TokenEnforcer):
        self.token_enforcer = token_enforcer
        self.mask: torch.Tensor | None = None

    def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        token_sequence = tokens[0, :].tolist()
        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)

        if self.mask is not None:
            self.mask.fill_(-math.inf)
        else:
            self.mask = torch.full_like(scores, -math.inf)

        self.mask[:, :, allowed_tokens] = 0
        scores = scores + self.mask
        return scores


def get_logits_processor(
    tokenizer: Tokenizer,
    vocab_size: int,
    response_format: ResponseFormat | None,
) -> Optional["LogitsProcessor"]:
    if response_format is None:
        return None

    if not isinstance(response_format, JsonSchemaResponseFormat):
        raise ValueError(f"Unsupported response format type {response_format.type}")

    parser = JsonSchemaParser(response_format.json_schema)
    data = TokenEnforcerTokenizerData(
        _build_regular_tokens_list(tokenizer, vocab_size),
        tokenizer.decode,
        tokenizer.stop_tokens,
    )
    token_enforcer = TokenEnforcer(data, parser)
    return LogitsProcessor(token_enforcer)


def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
    regular_tokens = []

    special_token_ids = set(tokenizer.special_tokens.values())
    for token_idx in range(vocab_size):
        if token_idx in special_token_ids:
            continue

        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
        decoded_regular = tokenizer.decode([token_idx])
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
    return regular_tokens


def _infer_sampling_params(sampling_params: SamplingParams):
    if isinstance(sampling_params.strategy, GreedySamplingStrategy):
        temperature = 0.0
        top_p = 1.0
    elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
        temperature = sampling_params.strategy.temperature or 1.0
        top_p = sampling_params.strategy.top_p or 1.0
    else:
        raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
    return temperature, top_p


def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
    tool_config = request.tool_config
    if tool_config is not None and tool_config.tool_prompt_format is not None:
        return tool_config.tool_prompt_format
    else:
        return get_default_tool_prompt_format(request.model)


class LlamaGenerator:
    def __init__(
        self,
        config: MetaReferenceInferenceConfig,
        model_id: str,
        llama_model: Model,
    ):
        if config.checkpoint_dir and config.checkpoint_dir != "null":
            ckpt_dir = config.checkpoint_dir
        else:
            resolved_model = resolve_model(model_id)
            if resolved_model is None:
                # if the model is not a native llama model, get the default checkpoint_dir based on model id
                ckpt_dir = model_checkpoint_dir(model_id)
            else:
                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

        if config.quantization:
            if config.quantization.type == "fp8_mixed":
                quantization_mode = QuantizationMode.fp8_mixed
            elif config.quantization.type == "int4_mixed":
                quantization_mode = QuantizationMode.int4_mixed
            elif config.quantization.type == "bf16":
                quantization_mode = None
            else:
                raise ValueError(f"Unsupported quantization mode {config.quantization}")
        else:
            quantization_mode = None

        cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
        self.inner_generator = cls.build(
            ckpt_dir=ckpt_dir,
            max_seq_len=config.max_seq_len,
            max_batch_size=config.max_batch_size,
            world_size=config.model_parallel_size or llama_model.pth_file_count,
            quantization_mode=quantization_mode,
        )

        self.tokenizer = self.inner_generator.tokenizer
        self.args = self.inner_generator.args
        self.formatter = self.inner_generator.formatter

    def completion(
        self,
        request_batch: list[CompletionRequestWithRawContent],
    ) -> Generator:
        first_request = request_batch[0]
        sampling_params = first_request.sampling_params or SamplingParams()
        max_gen_len = sampling_params.max_tokens
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
            max_gen_len = self.args.max_seq_len - 1

        temperature, top_p = _infer_sampling_params(sampling_params)
        yield from self.inner_generator.generate(
            llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=bool(first_request.logprobs),
            echo=False,
            logits_processor=get_logits_processor(
                self.tokenizer,
                self.args.vocab_size,
                first_request.response_format,
            ),
        )

    def chat_completion(
        self,
        request_batch: list[ChatCompletionRequestWithRawContent],
    ) -> Generator:
        first_request = request_batch[0]
        sampling_params = first_request.sampling_params or SamplingParams()
        max_gen_len = sampling_params.max_tokens
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
            max_gen_len = self.args.max_seq_len - 1

        temperature, top_p = _infer_sampling_params(sampling_params)
        yield from self.inner_generator.generate(
            llm_inputs=[
                self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
                for request in request_batch
            ],
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=bool(first_request.logprobs),
            echo=False,
            logits_processor=get_logits_processor(
                self.tokenizer,
                self.args.vocab_size,
                first_request.response_format,
            ),
        )
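A small sketch of the sampling-parameter mapping implemented by `_infer_sampling_params` above; the import path is assumed from the provider layout and the strategy constructor arguments are illustrative:

```python
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

# Assumed module path for the generator helpers shown above.
from llama_stack.providers.inline.inference.meta_reference.generators import _infer_sampling_params

params = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))
temperature, top_p = _infer_sampling_params(params)
assert (temperature, top_p) == (0.7, 0.9)
```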
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from collections.abc import AsyncIterator

from llama_stack.apis.inference import (
    InferenceProvider,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)

from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator

log = get_logger(__name__, category="inference")
# there's a single model-parallel process serving the model; for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)


def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    return LlamaGenerator(config, model_id, llama_model)


class MetaReferenceInferenceImpl(
    SentenceTransformerEmbeddingMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        self.config = config
        self.model_id = None
        self.llama_model = None

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        if self.config.create_distributed_process_group:
            self.generator.stop()

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        raise NotImplementedError("OpenAI completion not supported by meta reference provider")

    async def should_refresh_models(self) -> bool:
        return False

    async def list_models(self) -> list[Model] | None:
        return None

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        llama_model = (
            resolve_model(model.metadata["llama_model"])
            if "llama_model" in model.metadata
            else resolve_model(model.identifier)
        )
        if llama_model is None:
            raise ValueError(
                "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
            )

        self.model_registry_helper = ModelRegistryHelper(
            [
                build_hf_repo_model_entry(
                    llama_model.descriptor(),
                    llama_model.core_model_id.value,
                )
            ],
        )
        model = await self.model_registry_helper.register_model(model)

        if model.model_type == ModelType.embedding:
            self._load_sentence_transformer_model(model.provider_resource_id)

        # TODO: what is this?! you can't really specify skipping via model metadata
        # kill this madness
        if "skip_load" in model.metadata and model.metadata["skip_load"]:
            return model

        await self.load_model(model.identifier, llama_model)
        return model

    async def load_model(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                builder_fn=llama_builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
                    if llama_model.model_family == ModelFamily.llama4
                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
                ),
            )
            self.generator.start()
        else:
            self.generator = llama_builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

        log.info("Warming up...")
        await self.openai_chat_completion(
            model=model_id,
            messages=[{"role": "user", "content": "Hi how are you?"}],
            max_tokens=20,
        )
        log.info("Warmed up!")

    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
                "No available model yet; please register your requested model or add your model to the resources first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Callable, Generator
from copy import deepcopy
from functools import partial
from typing import Any

from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

from .parallel_utils import ModelParallelProcessGroup


class ModelRunner:
    def __init__(self, llama):
        self.llama = llama

    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
    def __call__(self, task: Any):
        if task[0] == "chat_completion":
            return self.llama.chat_completion(task[1])
        else:
            raise ValueError(f"Unexpected task type {task[0]}")


def init_model_cb(
    builder_fn: Callable,
    params: list[Any],
):
    llama = builder_fn(*params)
    return ModelRunner(llama)


class LlamaModelParallelGenerator:
    """
    This abstraction exists so that
    - we can run model parallel code without needing to run the CLIs via torchrun
    - we can also use model parallel code within a notebook context.

    A context manager is used to ensure that the model parallel process is started and stopped
    correctly. This does make the ergonomics a little awkward, because it isn't immediately
    clear at the callsite why we need to use a context manager.
    """

    def __init__(
        self,
        model_parallel_size: int,
        builder_fn: Callable,
        builder_params: list[Any],
        formatter: Llama3ChatFormat | Llama4ChatFormat,
    ):
        self.model_parallel_size = model_parallel_size
        self.builder_fn = builder_fn
        self.builder_params = builder_params
        self.formatter = formatter

    def start(self):
        self.__enter__()

    def stop(self):
        self.__exit__(None, None, None)

    def __enter__(self):
        self.group = ModelParallelProcessGroup(
            self.model_parallel_size,
            init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
        )
        self.group.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.group.stop()

    def completion(
        self,
        request_batch: list[CompletionRequestWithRawContent],
    ) -> Generator:
        req_obj = deepcopy(request_batch)
        gen = self.group.run_inference(("completion", req_obj))
        yield from gen

    def chat_completion(
        self,
        request_batch: list[ChatCompletionRequestWithRawContent],
    ) -> Generator:
        req_obj = deepcopy(request_batch)
        gen = self.group.run_inference(("chat_completion", req_obj))
        yield from gen
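The docstring above describes the start/stop lifecycle; a hedged sketch of the intended call pattern (mirroring how `load_model` in inference.py wires it up; the builder arguments below are placeholders for the real config, model id, model object, request, and chat formatter, not values from this diff):

```python
generator = LlamaModelParallelGenerator(
    model_parallel_size=1,
    builder_fn=llama_builder_fn,                      # defined in inference.py in this commit
    builder_params=[config, model_id, llama_model],   # placeholders for the real objects
    formatter=formatter,                              # a Llama3ChatFormat or Llama4ChatFormat instance
)
generator.start()
try:
    # request is a prepared ChatCompletionRequestWithRawContent (placeholder here)
    for generation_batch in generator.chat_completion([request]):
        ...
finally:
    generator.stop()
```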
@@ -0,0 +1,363 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import copy
import json
import multiprocessing
import os
import tempfile
import time
import uuid
from collections.abc import Callable, Generator
from enum import Enum
from typing import Annotated, Literal

import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_group,
    get_model_parallel_rank,
    get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

log = get_logger(name=__name__, category="inference")


class ProcessingMessageName(str, Enum):
    ready_request = "ready_request"
    ready_response = "ready_response"
    end_sentinel = "end_sentinel"
    cancel_sentinel = "cancel_sentinel"
    task_request = "task_request"
    task_response = "task_response"
    exception_response = "exception_response"


class ReadyRequest(BaseModel):
    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request


class ReadyResponse(BaseModel):
    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response


class EndSentinel(BaseModel):
    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel


class CancelSentinel(BaseModel):
    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
    task: tuple[
        str,
        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
    ]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
    result: list[GenerationResult]


class ExceptionResponse(BaseModel):
    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
    error: str


ProcessingMessage = (
    ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
)


class ProcessingMessageWrapper(BaseModel):
    payload: Annotated[
        ProcessingMessage,
        Field(discriminator="type"),
    ]


def mp_rank_0() -> bool:
    return bool(get_model_parallel_rank() == 0)


def encode_msg(msg: ProcessingMessage) -> bytes:
    return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")


def retrieve_requests(reply_socket_url: str):
    if mp_rank_0():
        context = zmq.Context()
        reply_socket = context.socket(zmq.ROUTER)
        reply_socket.connect(reply_socket_url)

        while True:
            client_id, obj = maybe_get_work(reply_socket)
            if obj is None:
                time.sleep(0.01)
                continue

            ready_response = ReadyResponse()
            reply_socket.send_multipart([client_id, encode_msg(ready_response)])
            break

    def send_obj(obj: ProcessingMessage):
        reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks: list[ProcessingMessage | None] = [None]
        if mp_rank_0():
            client_id, maybe_task_json = maybe_get_work(reply_socket)
            if maybe_task_json is not None:
                task = maybe_parse_message(maybe_task_json)
                # there is still an unknown unclean GeneratorExit happening resulting in a
                # cancel sentinel getting queued _after_ we have finished sending everything :/
                # kind of a hack this is :/
                if task is not None and not isinstance(task, CancelSentinel):
                    tasks = [task]

        torch.distributed.broadcast_object_list(
            tasks,
            src=get_model_parallel_src_rank(),
            group=get_model_parallel_group(),
        )

        task = tasks[0]
        if task is None:
            time.sleep(0.1)
        else:
            try:
                out = yield task
                if out is None:
                    break

                for obj in out:
                    updates: list[ProcessingMessage | None] = [None]
                    if mp_rank_0():
                        _, update_json = maybe_get_work(reply_socket)
                        update = maybe_parse_message(update_json)
                        if isinstance(update, CancelSentinel):
                            updates = [update]
                        else:
                            # only send the update if it's not cancelled otherwise the object sits in the socket
                            # and gets pulled in the next request lol
                            send_obj(TaskResponse(result=obj))

                    torch.distributed.broadcast_object_list(
                        updates,
                        src=get_model_parallel_src_rank(),
                        group=get_model_parallel_group(),
                    )
                    if isinstance(updates[0], CancelSentinel):
                        log.info("quitting generation loop because request was cancelled")
                        break

                if mp_rank_0():
                    send_obj(EndSentinel())
            except Exception as e:
                log.exception("exception in generation loop")

                if mp_rank_0():
                    send_obj(ExceptionResponse(error=str(e)))

    if mp_rank_0():
        send_obj(EndSentinel())


def maybe_get_work(sock: zmq.Socket):
    message = None
    client_id = None
    try:
        client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
        message = obj.decode("utf-8")
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise e

    return client_id, message


def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
    if maybe_json is None:
        return None
    try:
        return parse_message(maybe_json)
    except json.JSONDecodeError:
        return None
    except ValueError:
        return None


def parse_message(json_str: str) -> ProcessingMessage:
    data = json.loads(json_str)
    return copy.deepcopy(ProcessingMessageWrapper(**data).payload)


def worker_process_entrypoint(
    reply_socket_url: str,
    init_model_cb: Callable,
) -> None:
    model = init_model_cb()
    torch.distributed.barrier()
    time.sleep(1)

    # run the requests co-routine which retrieves requests from the socket
    # and sends responses (we provide) back to the caller
    req_gen = retrieve_requests(reply_socket_url)
    result = None
    while True:
        try:
            task = req_gen.send(result)
            if isinstance(task, EndSentinel):
                break

            assert isinstance(task, TaskRequest), task
            result = model(task.task)
        except StopIteration:
            break

    log.info("[debug] worker process done")


def launch_dist_group(
    reply_socket_url: str,
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
        launch_config = LaunchConfig(
            max_nodes=1,
            min_nodes=1,
            nproc_per_node=model_parallel_size,
            start_method="fork",
            rdzv_backend="c10d",
            rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
            rdzv_configs={"store_type": "file", "timeout": 90},
            max_restarts=0,
            monitor_interval=1,
            run_id=str(uuid.uuid4()),
        )
        elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
            reply_socket_url,
            init_model_cb,
        )


def start_model_parallel_process(
    model_parallel_size: int,
    init_model_cb: Callable,
    **kwargs,
):
    context = zmq.Context()
    request_socket = context.socket(zmq.DEALER)

    # Binding the request socket to a random port
    request_socket.bind("tcp://127.0.0.1:0")

    main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)

    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(
        target=launch_dist_group,
        args=(
            main_process_url,
            model_parallel_size,
            init_model_cb,
        ),
        kwargs=kwargs,
    )
    process.start()

    # wait until the model is loaded; rank 0 will send a message to indicate it's ready

    request_socket.send(encode_msg(ReadyRequest()))
    _response = request_socket.recv()
    log.info("Loaded model...")

    return request_socket, process


class ModelParallelProcessGroup:
    def __init__(
        self,
        model_parallel_size: int,
        init_model_cb: Callable,
        **kwargs,
    ):
        self.model_parallel_size = model_parallel_size
        self.init_model_cb = init_model_cb
        self.started = False
        self.running = False

    def start(self):
        assert not self.started, "process group already started"
        self.request_socket, self.process = start_model_parallel_process(
            self.model_parallel_size,
            self.init_model_cb,
        )
        self.started = True

    def stop(self):
        assert self.started, "process group not started"
        if self.process.is_alive():
            self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
            self.process.join()
        self.started = False

    def run_inference(
        self,
        req: tuple[
            str,
            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
        ],
    ) -> Generator:
        assert not self.running, "inference already running"

        self.running = True
        try:
            self.request_socket.send(encode_msg(TaskRequest(task=req)))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)

                if isinstance(obj, EndSentinel):
                    break

                if isinstance(obj, ExceptionResponse):
                    log.error(f"[debug] got exception {obj.error}")
                    raise Exception(obj.error)

                if isinstance(obj, TaskResponse):
                    yield obj.result

        except GeneratorExit:
            self.request_socket.send(encode_msg(CancelSentinel()))
            while True:
                obj_json = self.request_socket.recv()
                obj = parse_message(obj_json)
                if isinstance(obj, EndSentinel):
                    break
        finally:
            self.running = False
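To make the wire format above concrete: every message is a Pydantic model wrapped in the `type`-discriminated `ProcessingMessageWrapper` and shipped as JSON bytes over ZMQ, so `encode_msg` and `parse_message` round-trip cleanly. A small sketch (the import path is assumed from the provider layout):

```python
# Assumed module path for the utilities shown above.
from llama_stack.providers.inline.inference.meta_reference.parallel_utils import (
    TaskResponse,
    encode_msg,
    parse_message,
)

# JSON along the lines of {"payload": {"type": "task_response", "result": []}}
wire_bytes = encode_msg(TaskResponse(result=[]))
message = parse_message(wire_bytes.decode("utf-8"))
assert isinstance(message, TaskResponse)
```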
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.inline.inference.sentence_transformers.config import (
    SentenceTransformersInferenceConfig,
)


async def get_provider_impl(
    config: SentenceTransformersInferenceConfig,
    _deps: dict[str, Any],
):
    from .sentence_transformers import SentenceTransformersInferenceImpl

    impl = SentenceTransformersInferenceImpl(config)
    await impl.initialize()
    return impl
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel


class SentenceTransformersInferenceConfig(BaseModel):
    @classmethod
    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
        return {}
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from llama_stack.apis.inference import (
    InferenceProvider,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
)

from .config import SentenceTransformersInferenceConfig

log = get_logger(name=__name__, category="inference")


class SentenceTransformersInferenceImpl(
    OpenAIChatCompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
    __provider_id__: str

    def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def should_refresh_models(self) -> bool:
        return False

    async def list_models(self) -> list[Model] | None:
        return [
            Model(
                identifier="nomic-ai/nomic-embed-text-v1.5",
                provider_resource_id="nomic-ai/nomic-embed-text-v1.5",
                provider_id=self.__provider_id__,
                metadata={
                    "embedding_dimension": 768,
                },
                model_type=ModelType.embedding,
            ),
        ]

    async def register_model(self, model: Model) -> Model:
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")