llama-stack-mirror/llama_stack/providers/inline/inference/meta_reference/inference.py
Botao Chen 36b4fe02cc
[4/n][torchtune integration] support lazy load model during inference (#620)
## What does this PR do?
In this PR, we refactor the meta reference inference logic to support 
- load the model during registering model instead of during spinning up
server
- support inference finetuned model checkpoint on top of native llama
model

## Why need these changes
To solve the existing pain points that 
- user cannot lazy load the model and hot switch the inference
checkpoint after spinning up the server
- this blocks us doing inference and eval on the same sever for a
finetuned checkpoint after post training
- user cannot do inference on a finetuned checkpoint on top of native
llama models

## Expect user experience change
- The inference model won't be loaded when spinning up server. Instead,
it will be loaded during register model. If user add the model as models
resource in run.yaml, it will be registered and loaded automatically
when starting server. There is an optional flag 'skip_initialize' in
model metadata to skip model loading during registration.
- There is an optional flag 'llama_model' in model metadata to identify
the base model of the Model class for validation and initialize model
arch. model identifier no longer needs to be a native llama model
- the default inference model name updates from
'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct'
- It aligns with the checkpoint folder name after running 'llama model
download'
- It aligns with the descriptor name defined in llama-models SKU list
bf5b0c4fe7/models/datatypes.py (L95)


## test
run python llama_stack/scripts/distro_codegen.py


**run unit test**
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_text_inference.py
- torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.1-8B-Instruct"
./llama_stack/providers/tests/inference/test_model_registration.py


**test post training experience**
on server side run: llama stack run
llama_stack/templates/experimental-post-training/run.yaml
server is spinning up without model loaded

<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 models register
Llama3.2-3B-Instruct
register model successfully and the model is loaded 
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>


<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>

if add "skip_initialize" in metadata, model is registered but isn't
loaded

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"

Inference the model succesfully
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>

**test inference experience**
run: llama stack run llama_stack/templates/meta-reference-gpu/run.yaml
model is loaded since the model is in resouce list in run.yaml 
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>

on client side, run: llama-stack-client --endpoint
http://devgpu018.nha2.facebook.com:5000 inference chat-completion
--message "hello, what model are you?"
inference successfully 
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>



## inference on a finetuned model
**register a finetuned model that finetuned by post training api
(torchtune)**
- the model is registered and loaded successfully 
- the model is shown up in the model list 
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>

**run inference**

<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
2024-12-18 16:30:53 -08:00

465 lines
16 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model_id = None
self.llama_model = None
async def initialize(self) -> None:
pass
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
self.model_id = model_id
self.llama_model = llama_model
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
else resolve_model(model.identifier)
)
if llama_model is None:
raise ValueError(
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
llama_model.descriptor(),
llama_model.core_model_id.value,
)
],
)
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
if "skip_load" in model.metadata and model.metadata["skip_load"]:
return model
await self.load_model(model.identifier, llama_model)
return model
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token in tokenizer.stop_tokens:
# not quite right semantically
stop_reason = StopReason.end_of_turn
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = text
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x