Add an option to not use elastic agents for meta-reference inference (#269)

author Ashwin Bharambe, 2024-10-18 12:51:10 -07:00 (committed by GitHub)
parent be3c5c034d
commit 33afd34e6f
2 changed files with 33 additions and 8 deletions


@@ -17,13 +17,18 @@ from llama_stack.providers.utils.inference import supported_inference_models

 class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

+    # when this is False, we assume that the distributed process group is setup by someone
+    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
+    # (including our testing code) who might be using llama-stack as a library.
+    create_distributed_process_group: bool = True
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
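For context on how the new flag is meant to be used: a client that embeds llama-stack as a library (or a test harness running under `torchrun`) would construct the config with the flag turned off. The sketch below is illustrative only; the absolute import path is an assumption, since the diff shows just the provider-internal `from .config import MetaReferenceInferenceConfig`.

# Illustrative sketch, not part of the commit. The module path below is a
# guess at where MetaReferenceInferenceConfig lives; the diff only shows
# the relative import used inside the provider package.
from llama_stack.providers.impls.meta_reference.inference.config import (  # hypothetical path
    MetaReferenceInferenceConfig,
)

config = MetaReferenceInferenceConfig(
    model="Llama3.2-3B-Instruct",
    max_seq_len=4096,
    max_batch_size=1,
    # the surrounding launcher (e.g. `torchrun`) owns the process group,
    # so the provider should not create one itself
    create_distributed_process_group=False,
)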


@@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import MetaReferenceInferenceConfig
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,

@@ -36,8 +37,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        self.generator = LlamaModelParallelGenerator(self.config)
-        self.generator.start()
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config)

     async def register_model(self, model: ModelDef) -> None:
         raise ValueError("Dynamic model registration is not supported")
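The `else` branch calls `Llama.build(self.config)` directly and relies on the distributed process group already existing. What "set up by someone outside of this code" typically means in practice is the standard `torchrun` bootstrap sketched below; this is generic PyTorch, an assumption about the external environment rather than code from this repository.

# Generic PyTorch pattern (an assumption about the external setup, not part
# of this diff): under `torchrun`, RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
# are already in the environment, so the caller only needs to initialize the
# default process group before the provider's initialize() runs.
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")  # "gloo" for CPU-only runs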
@@ -51,7 +55,8 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]

     async def shutdown(self) -> None:
-        self.generator.stop()
+        if self.config.create_distributed_process_group:
+            self.generator.stop()

     def completion(
         self,

@@ -99,8 +104,9 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-        if SEMAPHORE.locked():
-            raise RuntimeError("Only one concurrent request is supported")
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")

         if request.stream:
             return self._stream_chat_completion(request)
@@ -110,7 +116,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)

             tokens = []

@@ -154,10 +160,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)

             yield ChatCompletionResponseStreamChunk(
@@ -272,6 +284,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
             )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
     async def embeddings(
         self,
         model: str,
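Both chat-completion paths now follow the same shape: the work moves into a local `impl()` closure, and `SEMAPHORE` is acquired only when the provider owns the model-parallel process group. A distilled, self-contained sketch of that guard pattern, with illustrative names rather than the repository's, looks like this:

# Standalone sketch of the conditional-semaphore pattern used above.
# `use_lock` plays the role of config.create_distributed_process_group.
import asyncio
from typing import AsyncGenerator, Callable, Iterable

SEMAPHORE = asyncio.Semaphore(1)  # at most one in-flight request when we own the group

async def guarded_call(use_lock: bool, impl: Callable[[], object]) -> object:
    # non-streaming path: run impl() under the semaphore only when required
    if use_lock:
        async with SEMAPHORE:
            return impl()
    return impl()

async def guarded_stream(use_lock: bool, impl: Callable[[], Iterable]) -> AsyncGenerator:
    # streaming path: hold the semaphore for the whole life of the generator
    if use_lock:
        async with SEMAPHORE:
            for x in impl():
                yield x
    else:
        for x in impl():
            yield x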