feat: add batch inference API to llama stack inference

This commit is contained in:
Ashwin Bharambe 2025-04-08 13:50:52 -07:00
parent ed58a94b30
commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions

View file

@ -9,6 +9,7 @@ import logging
import os
from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
@ -17,6 +18,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -38,6 +41,7 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
@ -65,7 +69,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .generators import Llama3Generator, Llama4Generator
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -74,12 +78,8 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
return Llama3Generator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn,
builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
)
self.generator.start()
else:
self.generator = builder_fn(*builder_params)
self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
print("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
print("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if request.logprobs:
assert len(token_result.logprobs) == 1
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="")
tokens.append(token_result.token)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
if request.logprobs:
assert len(token_result.logprobs) == 1
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="")
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)