commit 693c709c27
parent a5d6ab16b2
Author: Eric Huang
Date:   2025-04-24 12:03:55 -07:00

2 changed files with 103 additions and 34 deletions

Changed file 1 of 2: the meta-reference inference implementation (MetaReferenceInferenceImpl).

@@ -8,9 +8,6 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Optional, Union
 
-from pydantic import BaseModel
-from termcolor import cprint
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
     build_hf_repo_model_entry,
+    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
@@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_request_to_raw,
 )
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceInferenceConfig
 from .generators import LlamaGenerator
 from .model_parallel import LlamaModelParallelGenerator
@@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+def llama_builder_fn(
+    config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model
+) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
@@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size
+                or llama_model.pth_file_count,
                 builder_fn=llama_builder_fn,
                 builder_params=builder_params,
                 formatter=(
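The builder function is kept as a small top-level callable so it can be shipped to each model-parallel worker and invoked there, letting the heavyweight generator be constructed inside the worker process rather than pickled across it. A minimal sketch of that pattern, with toy names and simplified multiprocessing plumbing standing in for the real LlamaModelParallelGenerator:

# Assumption-level sketch of the builder-fn pattern, not the actual
# LlamaModelParallelGenerator: the parent passes a picklable top-level
# function plus its arguments, and each worker calls it locally so the
# model is built in-process.
from multiprocessing import Process
from typing import Any, Callable, List


def run_worker(builder_fn: Callable[..., Any], builder_params: List[Any]) -> None:
    generator = builder_fn(*builder_params)  # model constructed inside the worker
    print(f"worker built {generator!r}")


def toy_builder_fn(config: str, model_id: str) -> str:
    # stands in for llama_builder_fn(config, model_id, llama_model)
    return f"Generator({config}, {model_id})"


if __name__ == "__main__":
    p = Process(target=run_worker, args=(toy_builder_fn, ["cfg", "llama-3"]))
    p.start()
    p.join()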
@@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl(
                 "No avaible model yet, please register your requested model or add your model in the resouces first"
             )
         elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
+            )
 
     async def completion(
         self,
@@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl(
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
         content_batch = [
-            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+            augment_content_with_response_format_prompt(response_format, content)
+            for content in content_batch
         ]
 
         request_batch = []
@@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl(
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.completion(request):
+            for token_results in self.generator.completion([request]):
+                token_result = token_results[0]
                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
@@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
+                    logprobs = [
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    ]
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
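For reference, each streamed chunk carries the delta text plus an optional list holding one TokenLogProbs entry whose logprobs_by_token maps the decoded token text to its log-probability. A standalone sketch of reading that shape, using stand-in dataclasses rather than the llama_stack types (only the field names visible in this diff are assumed):

# Stand-in types mirroring only the fields shown in the hunk above; this is
# an assumption-level sketch, not the llama_stack API surface.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyTokenLogProbs:
    logprobs_by_token: Dict[str, float]


@dataclass
class ToyStreamChunk:
    delta: str
    logprobs: Optional[List[ToyTokenLogProbs]] = None


chunks = [
    ToyStreamChunk(delta="Hello", logprobs=[ToyTokenLogProbs({"Hello": -0.12})]),
    ToyStreamChunk(delta=" world", logprobs=[ToyTokenLogProbs({" world": -0.34})]),
]

# Reassemble the text and total log-probability from the streamed chunks.
text = "".join(chunk.delta for chunk in chunks)
total_logprob = sum(
    lp
    for chunk in chunks
    if chunk.logprobs
    for lp in chunk.logprobs[0].logprobs_by_token.values()
)
print(text, total_logprob)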
@@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl(
             for x in impl():
                 yield x
 
-    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
+    async def _nonstream_completion(
+        self, request_batch: List[CompletionRequest]
+    ) -> List[CompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
         first_request = request_batch[0]
@@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl(
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl(
         self.check_model(request)
 
         # augment and rewrite messages depending on the model
-        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+        request.messages = chat_completion_request_to_messages(
+            request, self.llama_model.core_model_id.value
+        )
         # download media and convert to raw content so we can send it to the model
         request = await convert_request_to_raw(request)
@@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl(
             self.check_model(request)
 
             # augment and rewrite messages depending on the model
-            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+            request.messages = chat_completion_request_to_messages(
+                request, self.llama_model.core_model_id.value
+            )
             # download media and convert to raw content so we can send it to the model
             request = await convert_request_to_raw(request)
             request_batch.append(request)
@@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl(
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl(
             if state.stop_reason is None:
                 state.stop_reason = StopReason.out_of_tokens
 
-            raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+            raw_message = self.generator.formatter.decode_assistant_message(
+                state.tokens, state.stop_reason
+            )
             results.append(
                 ChatCompletionResponse(
                     completion_message=CompletionMessage(
@@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl(
         else:
            return impl()
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         tokenizer = self.generator.formatter.tokenizer
 
         def impl():
@@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 tokens.append(token_result.token)
@@ -572,7 +609,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
 
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:

Changed file 2 of 2: the model-parallel process-group utilities (worker protocol messages and ModelParallelProcessGroup).

@@ -28,15 +28,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-from pydantic import BaseModel, Field
-from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-from typing_extensions import Annotated
 
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
+from pydantic import BaseModel, Field
+from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
 
 log = logging.getLogger(__name__)
@@ -52,33 +52,51 @@ class ProcessingMessageName(str, Enum):
 
 
 class ReadyRequest(BaseModel):
-    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
 
 
 class ReadyResponse(BaseModel):
-    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
 
 
 class EndSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
 
 
 class CancelSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
 
 
 class TaskRequest(BaseModel):
-    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: Tuple[
+        str,
+        List[CompletionRequestWithRawContent]
+        | List[ChatCompletionRequestWithRawContent],
+    ]
 
 
 class TaskResponse(BaseModel):
-    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: List[GenerationResult]
 
 
 class ExceptionResponse(BaseModel):
-    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str
 
 
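The parenthesized Literal defaults above are purely a formatting change; the type fields still act as discriminators so the worker protocol messages can be told apart after serialization. A self-contained sketch of that pattern with simplified payloads (pydantic v2 is assumed; the Envelope wrapper and the plain-string batch are stand-ins for illustration, not the file's real wrapper or request types):

# Assumption-level sketch of the discriminated message pattern, simplified so
# it runs standalone (pydantic v2 assumed).
from enum import Enum
from typing import List, Literal, Tuple, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class MsgName(str, Enum):
    task_request = "task_request"
    end_sentinel = "end_sentinel"


class ToyTaskRequest(BaseModel):
    type: Literal[MsgName.task_request] = MsgName.task_request
    # the real field is Tuple[str, List[CompletionRequestWithRawContent] | ...]
    task: Tuple[str, List[str]]


class ToyEndSentinel(BaseModel):
    type: Literal[MsgName.end_sentinel] = MsgName.end_sentinel


class Envelope(BaseModel):
    # the discriminator lets pydantic pick the right class when decoding
    payload: Annotated[Union[ToyTaskRequest, ToyEndSentinel], Field(discriminator="type")]


wire = Envelope(payload=ToyTaskRequest(task=("completion", ["hi"]))).model_dump_json()
decoded = Envelope.model_validate_json(wire)
assert isinstance(decoded.payload, ToyTaskRequest)
print(decoded.payload.task)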
@@ -172,7 +190,9 @@ def retrieve_requests(reply_socket_url: str):
                     group=get_model_parallel_group(),
                 )
                 if isinstance(updates[0], CancelSentinel):
-                    log.info("quitting generation loop because request was cancelled")
+                    log.info(
+                        "quitting generation loop because request was cancelled"
+                    )
                     break
 
             if mp_rank_0():
@@ -234,7 +254,7 @@ def worker_process_entrypoint(
             if isinstance(task, EndSentinel):
                 break
 
-            assert isinstance(task, TaskRequest)
+            assert isinstance(task, TaskRequest), task
             result = model(task.task)
         except StopIteration:
             break
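Adding `, task` to the assertion only changes the failure message: the offending message object is now included when the protocol is violated. A small illustrative sketch of the sentinel-driven worker loop shape with toy message types (not the real worker_process_entrypoint):

# Sketch of a sentinel-driven worker loop: process task messages until an end
# sentinel arrives; asserting with the message itself makes a protocol
# violation easy to diagnose.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ToyTask:
    task: str


@dataclass
class ToyEndSentinel:
    pass


def worker_loop(inbox: List[Union[ToyTask, ToyEndSentinel]]) -> List[str]:
    results = []
    for msg in inbox:
        if isinstance(msg, ToyEndSentinel):
            break
        assert isinstance(msg, ToyTask), msg  # include the message in the error
        results.append(msg.task.upper())
    return results


print(worker_loop([ToyTask("a"), ToyTask("b"), ToyEndSentinel(), ToyTask("ignored")]))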
@@ -331,7 +351,11 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
+        req: Tuple[
+            str,
+            List[CompletionRequestWithRawContent]
+            | List[ChatCompletionRequestWithRawContent],
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
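run_inference now takes a (tag, batch) tuple rather than a single request, matching the TaskRequest.task shape above. A hedged sketch of consuming such a tuple; the "completion" / "chat_completion" tag values and the dispatch logic are assumptions made for illustration, with strings standing in for the raw-content request types:

# Hedged sketch of a (tag, batch) request tuple consumer; tag values and
# dispatch are illustrative assumptions, not taken from this diff.
from typing import Iterator, List, Tuple


def toy_run_inference(req: Tuple[str, List[str]]) -> Iterator[List[str]]:
    tag, batch = req
    if tag == "completion":  # assumed tag value
        yield [f"completion for: {item}" for item in batch]
    elif tag == "chat_completion":  # assumed tag value
        yield [f"chat reply for: {item}" for item in batch]
    else:
        raise ValueError(f"unknown task tag: {tag}")


for results in toy_run_inference(("completion", ["Tell me a joke"])):
    print(results[0])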