Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-31 03:29:59 +00:00)
feat: add batch inference API to llama stack inference
parent ed58a94b30
commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions
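The diff below (one of the changed files, the meta-reference inference provider) introduces two new endpoints, batch_completion and batch_chat_completion, and reworks the non-streaming paths to operate on lists of requests. As a rough, hypothetical usage sketch (not taken from this commit): assuming an initialized MetaReferenceInferenceImpl instance `impl`, a registered `model_id`, and that SamplingParams and UserMessage are importable from llama_stack.apis.inference, the new methods could be driven roughly like this:

# Hypothetical usage sketch of the batch inference methods added in this commit.
# `impl` is assumed to be an initialized MetaReferenceInferenceImpl with a loaded
# model; import paths are assumptions and may differ in the actual tree.
from llama_stack.apis.inference import SamplingParams, UserMessage


async def demo_batch_inference(impl, model_id: str) -> None:
    # Several raw-text prompts completed in a single call.
    completion_batch = await impl.batch_completion(
        model_id=model_id,
        content_batch=["Say hello.", "Name one prime number."],
        sampling_params=SamplingParams(max_tokens=32),
    )
    for response in completion_batch.batch:
        print(response.content, response.stop_reason)

    # One message list per conversation, all decoded together.
    chat_batch = await impl.batch_chat_completion(
        model_id=model_id,
        messages_batch=[
            [UserMessage(content="Hi, how are you?")],
            [UserMessage(content="Give me a haiku about batching.")],
        ],
        sampling_params=SamplingParams(max_tokens=32),
    )
    for response in chat_batch.batch:
        print(response.completion_message.content)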
@@ -9,6 +9,7 @@ import logging
 import os
 from typing import AsyncGenerator, List, Optional, Union
 
+from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.common.content_types import (
@@ -17,6 +18,8 @@ from llama_stack.apis.common.content_types import (
     ToolCallParseStatus,
 )
 from llama_stack.apis.inference import (
+    BatchChatCompletionResponse,
+    BatchCompletionResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
@@ -38,6 +41,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    UserMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
@@ -65,7 +69,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 
 from .config import MetaReferenceInferenceConfig
-from .generators import Llama3Generator, Llama4Generator
+from .generators import LlamaGenerator
 from .model_parallel import LlamaModelParallelGenerator
 
 log = logging.getLogger(__name__)
@@ -74,12 +78,8 @@ log = logging.getLogger(__name__)
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
-    return Llama3Generator(config, model_id, llama_model)
-
-
-def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
-    return Llama4Generator(config, model_id, llama_model)
+def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+    return LlamaGenerator(config, model_id, llama_model)
 
 
 class MetaReferenceInferenceImpl(
@@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
     async def load_model(self, model_id, llama_model) -> None:
         log.info(f"Loading model `{model_id}`")
 
-        if llama_model.model_family in {
-            ModelFamily.llama3,
-            ModelFamily.llama3_1,
-            ModelFamily.llama3_2,
-            ModelFamily.llama3_3,
-        }:
-            builder_fn = llama3_builder_fn
-        elif llama_model.model_family == ModelFamily.llama4:
-            builder_fn = llama4_builder_fn
-        else:
-            raise ValueError(f"Unsupported model family: {llama_model.model_family}")
-
         builder_params = [self.config, model_id, llama_model]
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
                 model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
-                builder_fn=builder_fn,
+                builder_fn=llama_builder_fn,
                 builder_params=builder_params,
                 formatter=(
                     Llama4ChatFormat(Llama4Tokenizer.get_instance())
@@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
             )
             self.generator.start()
         else:
-            self.generator = builder_fn(*builder_params)
+            self.generator = llama_builder_fn(*builder_params)
 
         self.model_id = model_id
         self.llama_model = llama_model
 
+        print("Warming up...")
+        await self.completion(
+            model_id=model_id,
+            content="Hello, world!",
+            sampling_params=SamplingParams(max_tokens=10),
+        )
+        await self.chat_completion(
+            model_id=model_id,
+            messages=[UserMessage(content="Hi how are you?")],
+            sampling_params=SamplingParams(max_tokens=20),
+        )
+        print("Warmed up!")
+
     def check_model(self, request) -> None:
         if self.model_id is None or self.llama_model is None:
             raise RuntimeError(
@@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
         if request.stream:
             return self._stream_completion(request)
         else:
-            return await self._nonstream_completion(request)
+            results = await self._nonstream_completion([request])
+            return results[0]
+
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchCompletionResponse:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+        content_batch = [
+            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+        ]
+
+        request_batch = []
+        for content in content_batch:
+            request = CompletionRequest(
+                model=model_id,
+                content=content,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                stream=stream,
+                logprobs=logprobs,
+            )
+            self.check_model(request)
+            request = await convert_request_to_raw(request)
+            request_batch.append(request)
+
+        results = await self._nonstream_completion(request_batch)
+        return BatchCompletionResponse(batch=results)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         tokenizer = self.generator.formatter.tokenizer
@@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
             for x in impl():
                 yield x
 
-    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
+        first_request = request_batch[0]
+
+        class ItemState(BaseModel):
+            tokens: List[int] = []
+            logprobs: List[TokenLogProbs] = []
+            stop_reason: StopReason | None = None
+            finished: bool = False
+
         def impl():
-            tokens = []
-            logprobs = []
-            stop_reason = None
+            states = [ItemState() for _ in request_batch]
 
-            for token_result in self.generator.completion(request):
-                tokens.append(token_result.token)
-                if token_result.token == tokenizer.eot_id:
-                    stop_reason = StopReason.end_of_turn
-                elif token_result.token == tokenizer.eom_id:
-                    stop_reason = StopReason.end_of_message
+            results = []
+            for token_results in self.generator.completion(request_batch):
+                for result in token_results:
+                    idx = result.batch_idx
+                    state = states[idx]
+                    if state.finished or result.ignore_token:
+                        continue
 
-                if request.logprobs:
-                    assert len(token_result.logprobs) == 1
+                    state.finished = result.finished
+                    if first_request.logprobs:
+                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    state.tokens.append(result.token)
+                    if result.token == tokenizer.eot_id:
+                        state.stop_reason = StopReason.end_of_turn
+                    elif result.token == tokenizer.eom_id:
+                        state.stop_reason = StopReason.end_of_message
 
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+            for state in states:
+                if state.stop_reason is None:
+                    state.stop_reason = StopReason.out_of_tokens
 
-            if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
-                tokens = tokens[:-1]
-            content = self.generator.formatter.tokenizer.decode(tokens)
-            return CompletionResponse(
-                content=content,
-                stop_reason=stop_reason,
-                logprobs=logprobs if request.logprobs else None,
-            )
+                if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+                    state.tokens = state.tokens[:-1]
+                content = self.generator.formatter.tokenizer.decode(state.tokens)
+                results.append(
+                    CompletionResponse(
+                        content=content,
+                        stop_reason=state.stop_reason,
+                        logprobs=state.logprobs if first_request.logprobs else None,
+                    )
+                )
+
+            return results
 
         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
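The per-token objects consumed above (result.batch_idx, result.finished, result.ignore_token, and so on) are produced by the generator, whose class is not shown in this diff. A hypothetical sketch of the shape implied purely by the usage in these hunks, for orientation only:

# Hypothetical sketch (not the provider's real class): the per-token result
# shape that the batched _nonstream_* loops above appear to consume.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GenerationResult:  # hypothetical name
    token: int                  # sampled token id
    text: str                   # decoded text for this token
    batch_idx: int              # which request in the batch this result belongs to
    finished: bool              # True once this batch item has stopped generating
    ignore_token: bool          # True for filler steps after an item has finished
    logprobs: Optional[List[float]] = None  # present only when logprobs were requested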
@@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
-            tool_config=tool_config,
+            tool_config=tool_config or ToolConfig(),
         )
         self.check_model(request)
 
@@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
         if request.stream:
             return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request)
+            results = await self._nonstream_chat_completion([request])
+            return results[0]
 
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
+    ) -> BatchChatCompletionResponse:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request_batch = []
+        for messages in messages_batch:
+            request = ChatCompletionRequest(
+                model=model_id,
+                messages=messages,
+                sampling_params=sampling_params,
+                tools=tools or [],
+                response_format=response_format,
+                logprobs=logprobs,
+                tool_config=tool_config or ToolConfig(),
+            )
+            self.check_model(request)
+
+            # augment and rewrite messages depending on the model
+            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+            # download media and convert to raw content so we can send it to the model
+            request = await convert_request_to_raw(request)
+            request_batch.append(request)
+
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")
+
+        results = await self._nonstream_chat_completion(request_batch)
+        return BatchChatCompletionResponse(batch=results)
+
+    async def _nonstream_chat_completion(
+        self, request_batch: List[ChatCompletionRequest]
+    ) -> List[ChatCompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
+        first_request = request_batch[0]
+
+        class ItemState(BaseModel):
+            tokens: List[int] = []
+            logprobs: List[TokenLogProbs] = []
+            stop_reason: StopReason | None = None
+            finished: bool = False
+
         def impl():
-            tokens = []
-            logprobs = []
-            stop_reason = None
+            states = [ItemState() for _ in request_batch]
 
-            for token_result in self.generator.chat_completion(request):
-                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                    cprint(token_result.text, "cyan", end="")
+            for token_results in self.generator.chat_completion(request_batch):
+                first = token_results[0]
+                if not first.finished:
+                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
+                        cprint(first.text, "cyan", end="")
+                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                        cprint(f"<{first.token}>", "magenta", end="")
 
-                tokens.append(token_result.token)
+                for result in token_results:
+                    idx = result.batch_idx
+                    state = states[idx]
+                    if state.finished or result.ignore_token:
+                        continue
 
-                if token_result.token == tokenizer.eot_id:
-                    stop_reason = StopReason.end_of_turn
-                elif token_result.token == tokenizer.eom_id:
-                    stop_reason = StopReason.end_of_message
+                    state.finished = result.finished
+                    if first_request.logprobs:
+                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
 
-                if request.logprobs:
-                    assert len(token_result.logprobs) == 1
+                    state.tokens.append(result.token)
+                    if result.token == tokenizer.eot_id:
+                        state.stop_reason = StopReason.end_of_turn
+                    elif result.token == tokenizer.eom_id:
+                        state.stop_reason = StopReason.end_of_message
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+            results = []
+            for state in states:
+                if state.stop_reason is None:
+                    state.stop_reason = StopReason.out_of_tokens
 
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+                results.append(
+                    ChatCompletionResponse(
+                        completion_message=CompletionMessage(
+                            content=raw_message.content,
+                            stop_reason=raw_message.stop_reason,
+                            tool_calls=raw_message.tool_calls,
+                        ),
+                        logprobs=state.logprobs if first_request.logprobs else None,
+                    )
+                )
 
-            raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-            return ChatCompletionResponse(
-                completion_message=CompletionMessage(
-                    content=raw_message.content,
-                    stop_reason=raw_message.stop_reason,
-                    tool_calls=raw_message.tool_calls,
-                ),
-                logprobs=logprobs if request.logprobs else None,
-            )
+            return results
 
         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
@@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
             for token_result in self.generator.chat_completion(request):
                 if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                     cprint(token_result.text, "cyan", end="")
+                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                    cprint(f"<{token_result.token}>", "magenta", end="")
 
+                if token_result.token == tokenizer.eot_id:
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.token == tokenizer.eom_id:
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+
                 tokens.append(token_result.token)
 