Add an option to not use elastic agents for meta-reference inference (#269)

This commit is contained in:
Ashwin Bharambe 2024-10-18 12:51:10 -07:00 committed by GitHub
parent be3c5c034d
commit 33afd34e6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 8 deletions

View file

@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@ -36,8 +37,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
else:
self.generator = Llama.build(self.config)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Dynamic model registration is not supported")
@ -51,7 +55,8 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
]
async def shutdown(self) -> None:
self.generator.stop()
if self.config.create_distributed_process_group:
self.generator.stop()
def completion(
self,
@ -99,8 +104,9 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
@ -110,7 +116,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
def impl():
messages = chat_completion_request_to_messages(request)
tokens = []
@ -154,10 +160,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with SEMAPHORE:
def impl():
messages = chat_completion_request_to_messages(request)
yield ChatCompletionResponseStreamChunk(
@ -272,6 +284,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def embeddings(
self,
model: str,