Eric Huang 2025-04-24 12:03:55 -07:00
parent a5d6ab16b2
commit 693c709c27
2 changed files with 103 additions and 34 deletions

@@ -8,9 +8,6 @@ import asyncio
import os
from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
@@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
@@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
convert_request_to_raw,
)
from pydantic import BaseModel
from termcolor import cprint
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
@@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
def llama_builder_fn(
config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model
) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
model_parallel_size=self.config.model_parallel_size
or llama_model.pth_file_count,
builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
@@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
async def completion(
self,
@@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl(
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
augment_content_with_response_format_prompt(response_format, content)
for content in content_batch
]
request_batch = []
@@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl(
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
for token_results in self.generator.completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
@@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
yield CompletionResponseStreamChunk(
delta=text,
@@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
async def _nonstream_completion(
self, request_batch: List[CompletionRequest]
) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
@@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl(
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.logprobs.append(
TokenLogProbs(
logprobs_by_token={result.text: result.logprobs[0]}
)
)
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
@@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
@@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
@@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl(
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.logprobs.append(
TokenLogProbs(
logprobs_by_token={result.text: result.logprobs[0]}
)
)
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
@@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl(
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
raw_message = self.generator.formatter.decode_assistant_message(
state.tokens, state.stop_reason
)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
@@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl(
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
@@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
tokens.append(token_result.token)
@@ -572,7 +609,13 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
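
Taken together, the hunks above move this file's non-streaming paths onto request batches: requests are collected into request_batch, rewritten per request with chat_completion_request_to_messages and convert_request_to_raw, and the generator's per-step output is folded into per-request state (tokens, logprobs, finished, stop_reason). A minimal sketch of that accumulation loop, under the assumption that the generator yields one result per request per step; the state container and loop structure here are illustrative, only the field names come from the diff:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class _RequestState:
    tokens: List[int] = field(default_factory=list)
    stop_reason: Optional[str] = None
    finished: bool = False


def collect_batch(generator, request_batch, eot_id):
    states = [_RequestState() for _ in request_batch]
    for token_results in generator.completion(request_batch):
        for state, result in zip(states, token_results):
            if state.finished:
                continue
            state.tokens.append(result.token)
            state.finished = result.finished
            if result.token == eot_id:
                state.stop_reason = "end_of_turn"
    for state in states:
        if state.stop_reason is None:
            # ran out of token budget without hitting an end-of-turn token
            state.stop_reason = "out_of_tokens"
    return states

The hunks below belong to the commit's second changed file, which adapts the model-parallel worker protocol to carry these request batches.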

@@ -28,15 +28,15 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
get_model_parallel_src_rank,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
log = logging.getLogger(__name__)
@@ -52,33 +52,51 @@ class ProcessingMessageName(str, Enum):
class ReadyRequest(BaseModel):
type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
type: Literal[ProcessingMessageName.ready_request] = (
ProcessingMessageName.ready_request
)
class ReadyResponse(BaseModel):
type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
type: Literal[ProcessingMessageName.ready_response] = (
ProcessingMessageName.ready_response
)
class EndSentinel(BaseModel):
type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
type: Literal[ProcessingMessageName.end_sentinel] = (
ProcessingMessageName.end_sentinel
)
class CancelSentinel(BaseModel):
type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
type: Literal[ProcessingMessageName.cancel_sentinel] = (
ProcessingMessageName.cancel_sentinel
)
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: Tuple[
str,
List[CompletionRequestWithRawContent]
| List[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
type: Literal[ProcessingMessageName.task_response] = (
ProcessingMessageName.task_response
)
result: List[GenerationResult]
class ExceptionResponse(BaseModel):
type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
type: Literal[ProcessingMessageName.exception_response] = (
ProcessingMessageName.exception_response
)
error: str
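
Each message class above pins a Literal type field to its own ProcessingMessageName value, which is the standard pydantic tagged-union setup; the Annotated and Field imports reordered earlier in this file are what such a union typically uses for dispatch. A sketch of that pattern, assuming the message classes above are in scope and a pydantic v2 API; the envelope model and helper names are illustrative, not taken from this diff:

from typing import Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MessageEnvelope(BaseModel):
    # pydantic picks the concrete message class by matching the `type` discriminator
    payload: Annotated[
        Union[
            ReadyRequest,
            ReadyResponse,
            EndSentinel,
            CancelSentinel,
            TaskRequest,
            TaskResponse,
            ExceptionResponse,
        ],
        Field(discriminator="type"),
    ]


def parse_message(raw: str):
    return MessageEnvelope.model_validate_json(raw).payload
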
@@ -172,7 +190,9 @@ def retrieve_requests(reply_socket_url: str):
group=get_model_parallel_group(),
)
if isinstance(updates[0], CancelSentinel):
log.info("quitting generation loop because request was cancelled")
log.info(
"quitting generation loop because request was cancelled"
)
break
if mp_rank_0():
@@ -234,7 +254,7 @@ worker_process_entrypoint(
if isinstance(task, EndSentinel):
break
assert isinstance(task, TaskRequest)
assert isinstance(task, TaskRequest), task
result = model(task.task)
except StopIteration:
break
@@ -331,7 +351,11 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
req: Tuple[
str,
List[CompletionRequestWithRawContent]
| List[ChatCompletionRequestWithRawContent],
],
) -> Generator:
assert not self.running, "inference already running"
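
The widened run_inference signature mirrors the TaskRequest change above: a task is now a (task-type, request-batch) tuple rather than a single request. A hedged usage sketch follows; run_inference and the request types appear in the diff, while the "completion" task label and the shape of each yielded step are assumptions:

def generate_for_batch(group, request_batch):
    # group: ModelParallelProcessGroup (already started elsewhere)
    # request_batch: List[CompletionRequestWithRawContent]
    for step_results in group.run_inference(("completion", request_batch)):
        # each step is expected to carry one GenerationResult per request
        yield step_results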