From 693c709c276cdb8c9492b9f7a0771a9c47ba21e7 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Thu, 24 Apr 2025 12:03:55 -0700 Subject: [PATCH] fixes --- .../inference/meta_reference/inference.py | 85 ++++++++++++++----- .../meta_reference/parallel_utils.py | 52 +++++++++--- 2 files changed, 103 insertions(+), 34 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 0e69c2e7e..f25fc77f7 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -8,9 +8,6 @@ import asyncio import os from typing import AsyncGenerator, List, Optional, Union -from pydantic import BaseModel -from termcolor import cprint - from llama_stack.apis.common.content_types import ( TextDelta, ToolCallDelta, @@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, build_hf_repo_model_entry, + ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( OpenAIChatCompletionToLlamaStackMixin, @@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( convert_request_to_raw, ) +from pydantic import BaseModel +from termcolor import cprint + from .config import MetaReferenceInferenceConfig from .generators import LlamaGenerator from .model_parallel import LlamaModelParallelGenerator @@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) -def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: +def llama_builder_fn( + config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model +) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl( if self.config.create_distributed_process_group: self.generator = LlamaModelParallelGenerator( - model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count, + model_parallel_size=self.config.model_parallel_size + or llama_model.pth_file_count, builder_fn=llama_builder_fn, builder_params=builder_params, formatter=( @@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl( "No avaible model yet, please register your requested model or add your model in the resouces first" ) elif request.model != self.model_id: - raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}") + raise RuntimeError( + f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}" + ) async def completion( self, @@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl( assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" content_batch = [ - augment_content_with_response_format_prompt(response_format, content) for content in content_batch + augment_content_with_response_format_prompt(response_format, content) + for content in content_batch ] request_batch = [] @@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl( def impl(): stop_reason = None - for token_result in self.generator.completion(request): + for token_results in self.generator.completion([request]): + token_result = token_results[0] if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn text = "" @@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl( if 
request.logprobs: assert len(token_result.logprobs) == 1 - logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})] + logprobs = [ + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ] yield CompletionResponseStreamChunk( delta=text, @@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl( for x in impl(): yield x - async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]: + async def _nonstream_completion( + self, request_batch: List[CompletionRequest] + ) -> List[CompletionResponse]: tokenizer = self.generator.formatter.tokenizer first_request = request_batch[0] @@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl( state.finished = result.finished if first_request.logprobs: - state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + state.logprobs.append( + TokenLogProbs( + logprobs_by_token={result.text: result.logprobs[0]} + ) + ) state.tokens.append(result.token) if result.token == tokenizer.eot_id: @@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl( self.check_model(request) # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) + request.messages = chat_completion_request_to_messages( + request, self.llama_model.core_model_id.value + ) # download media and convert to raw content so we can send it to the model request = await convert_request_to_raw(request) @@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl( self.check_model(request) # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) + request.messages = chat_completion_request_to_messages( + request, self.llama_model.core_model_id.value + ) # download media and convert to raw content so we can send it to the model request = await convert_request_to_raw(request) request_batch.append(request) @@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl( state.finished = result.finished if first_request.logprobs: - state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + state.logprobs.append( + TokenLogProbs( + logprobs_by_token={result.text: result.logprobs[0]} + ) + ) state.tokens.append(result.token) if result.token == tokenizer.eot_id: @@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl( if state.stop_reason is None: state.stop_reason = StopReason.out_of_tokens - raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason) + raw_message = self.generator.formatter.decode_assistant_message( + state.tokens, state.stop_reason + ) results.append( ChatCompletionResponse( completion_message=CompletionMessage( @@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl( else: return impl() - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: tokenizer = self.generator.formatter.tokenizer def impl(): @@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl( if request.logprobs: assert len(token_result.logprobs) == 1 - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) + logprobs.append( + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ) tokens.append(token_result.token) @@ -572,7 +609,13 @@ 
class MetaReferenceInferenceImpl( if request.logprobs: assert len(token_result.logprobs) == 1 - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) + logprobs.append( + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, @@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl( if stop_reason is None: stop_reason = StopReason.out_of_tokens - message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) + message = self.generator.formatter.decode_assistant_message( + tokens, stop_reason + ) parsed_tool_calls = len(message.tool_calls) > 0 if ipython and not parsed_tool_calls: diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 9ffcf99fe..0cfd12705 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -28,15 +28,15 @@ from fairscale.nn.model_parallel.initialize import ( get_model_parallel_rank, get_model_parallel_src_rank, ) -from pydantic import BaseModel, Field -from torch.distributed.launcher.api import LaunchConfig, elastic_launch -from typing_extensions import Annotated from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) +from pydantic import BaseModel, Field +from torch.distributed.launcher.api import elastic_launch, LaunchConfig +from typing_extensions import Annotated log = logging.getLogger(__name__) @@ -52,33 +52,51 @@ class ProcessingMessageName(str, Enum): class ReadyRequest(BaseModel): - type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request + type: Literal[ProcessingMessageName.ready_request] = ( + ProcessingMessageName.ready_request + ) class ReadyResponse(BaseModel): - type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response + type: Literal[ProcessingMessageName.ready_response] = ( + ProcessingMessageName.ready_response + ) class EndSentinel(BaseModel): - type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel + type: Literal[ProcessingMessageName.end_sentinel] = ( + ProcessingMessageName.end_sentinel + ) class CancelSentinel(BaseModel): - type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel + type: Literal[ProcessingMessageName.cancel_sentinel] = ( + ProcessingMessageName.cancel_sentinel + ) class TaskRequest(BaseModel): - type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]] + type: Literal[ProcessingMessageName.task_request] = ( + ProcessingMessageName.task_request + ) + task: Tuple[ + str, + List[CompletionRequestWithRawContent] + | List[ChatCompletionRequestWithRawContent], + ] class TaskResponse(BaseModel): - type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response + type: Literal[ProcessingMessageName.task_response] = ( + ProcessingMessageName.task_response + ) result: List[GenerationResult] class ExceptionResponse(BaseModel): - type: 
Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response + type: Literal[ProcessingMessageName.exception_response] = ( + ProcessingMessageName.exception_response + ) error: str @@ -172,7 +190,9 @@ def retrieve_requests(reply_socket_url: str): group=get_model_parallel_group(), ) if isinstance(updates[0], CancelSentinel): - log.info("quitting generation loop because request was cancelled") + log.info( + "quitting generation loop because request was cancelled" + ) break if mp_rank_0(): @@ -234,7 +254,7 @@ def worker_process_entrypoint( if isinstance(task, EndSentinel): break - assert isinstance(task, TaskRequest) + assert isinstance(task, TaskRequest), task result = model(task.task) except StopIteration: break @@ -331,7 +351,11 @@ class ModelParallelProcessGroup: def run_inference( self, - req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]], + req: Tuple[ + str, + List[CompletionRequestWithRawContent] + | List[ChatCompletionRequestWithRawContent], + ], ) -> Generator: assert not self.running, "inference already running"
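
A few self-contained sketches of the patterns this patch relies on follow; every name in them is an illustrative stand-in unless it also appears in the diff above.

The functional change in `_stream_completion` is that the generator's `completion()` now takes a batch and yields a list of per-request results on each step, so the single-request streaming path wraps its request in a one-element list and unwraps index 0 (`token_results[0]`). A minimal sketch of that adaptation, with a fake token type and batch generator standing in for `LlamaGenerator`:

```python
# Minimal sketch (not the provider's real classes): adapting a single-request
# streaming loop to a generator that yields per-batch results, as the patch
# does with `self.generator.completion([request])`.
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class TokenResult:
    token: int
    text: str
    finished: bool


def batch_completion(requests: List[str]) -> Iterator[List[TokenResult]]:
    """Hypothetical batch generator: yields one TokenResult per request per step."""
    for step in range(3):
        yield [
            TokenResult(token=step, text=f"tok{step}", finished=(step == 2))
            for _ in requests
        ]


def stream_single(request: str) -> Iterator[TokenResult]:
    # Wrap the lone request in a one-element batch and unwrap index 0,
    # mirroring `token_results[0]` in the patched _stream_completion.
    for token_results in batch_completion([request]):
        yield token_results[0]


if __name__ == "__main__":
    for tr in stream_single("hello"):
        print(tr.text, tr.finished)
```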
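
The `parallel_utils.py` hunks reformat the `Literal`-tagged pydantic message classes used on the worker request/response sockets. The diff does not show the union type itself, but the `Annotated` and `Field` imports it keeps are the usual ingredients of a pydantic discriminated union over the `type` field; a standalone sketch of that pattern, with invented message names rather than the module's own:

```python
# Sketch of the Literal-tagged message pattern touched in parallel_utils.py;
# the enum, models, and Envelope here are illustrative, not the real protocol.
from enum import Enum
from typing import List, Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MessageName(str, Enum):
    ready_request = "ready_request"
    task_response = "task_response"


class ReadyRequest(BaseModel):
    type: Literal[MessageName.ready_request] = MessageName.ready_request


class TaskResponse(BaseModel):
    type: Literal[MessageName.task_response] = MessageName.task_response
    result: List[str]


# The "type" field acts as the discriminator, so a serialized message can be
# parsed back into the right class without trying each model in turn.
Message = Annotated[Union[ReadyRequest, TaskResponse], Field(discriminator="type")]


class Envelope(BaseModel):
    payload: Message


if __name__ == "__main__":
    raw = Envelope(payload=TaskResponse(result=["ok"])).model_dump_json()
    print(Envelope.model_validate_json(raw).payload)
```

Giving every message a `Literal` default means the discriminator is filled in at construction time, so callers never pass `type` explicitly.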
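
The `worker_process_entrypoint` hunk also tightens the task-dispatch assertion to `assert isinstance(task, TaskRequest), task`, which surfaces the offending message when the protocol is violated. The loop structure itself, reduced to a queue-based sketch with stand-in types in place of the module's ZMQ plumbing:

```python
# Sketch of a sentinel-terminated worker loop: keep pulling tasks until an end
# sentinel arrives, and assert that anything else is a real task before running it.
import queue
from dataclasses import dataclass
from typing import Union


@dataclass
class EndSentinel:
    pass


@dataclass
class TaskRequest:
    task: str


def worker_loop(tasks: "queue.Queue[Union[EndSentinel, TaskRequest]]") -> None:
    while True:
        item = tasks.get()
        if isinstance(item, EndSentinel):
            break
        # Including the offending item in the assertion message (as the patch
        # does) makes a protocol violation much easier to debug.
        assert isinstance(item, TaskRequest), item
        print(f"running {item.task}")


if __name__ == "__main__":
    q: "queue.Queue[Union[EndSentinel, TaskRequest]]" = queue.Queue()
    q.put(TaskRequest(task="completion"))
    q.put(EndSentinel())
    worker_loop(q)
```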