Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 10:54:19 +00:00)
Add an option to not use elastic agents for meta-reference inference (#269)
Commit: 33afd34e6f
Parent: be3c5c034d
2 changed files with 33 additions and 8 deletions
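Only the inference implementation is shown in the hunks below; the other changed file is presumably the provider config, which gains the `create_distributed_process_group` flag referenced throughout. A minimal sketch of what that config change likely looks like (the field name is taken from the diff; the default value and pydantic base class are assumptions):

from pydantic import BaseModel


class MetaReferenceInferenceConfig(BaseModel):
    # ... existing fields elided ...

    # Assumed sketch: when False, skip the elastic/model-parallel process group
    # and load the model in-process via Llama.build().
    create_distributed_process_group: bool = True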
@@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import MetaReferenceInferenceConfig
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,
@@ -36,8 +37,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        self.generator = LlamaModelParallelGenerator(self.config)
-        self.generator.start()
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config)

     async def register_model(self, model: ModelDef) -> None:
         raise ValueError("Dynamic model registration is not supported")
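For illustration, toggling the new flag off makes initialize() call Llama.build() directly instead of spinning up the model-parallel generator. A hypothetical usage sketch (only `create_distributed_process_group` is confirmed by this diff; the other constructor arguments and the provider entry point are assumed):

# Somewhere inside an async context:
config = MetaReferenceInferenceConfig(
    model="Llama3.1-8B-Instruct",             # assumed field
    create_distributed_process_group=False,   # run in-process, no elastic agents
)
impl = MetaReferenceInferenceImpl(config)     # assumed constructor shape
await impl.initialize()                       # loads via Llama.build(self.config)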
@@ -51,7 +55,8 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]

     async def shutdown(self) -> None:
-        self.generator.stop()
+        if self.config.create_distributed_process_group:
+            self.generator.stop()

     def completion(
         self,
@@ -99,8 +104,9 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-        if SEMAPHORE.locked():
-            raise RuntimeError("Only one concurrent request is supported")
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")

         if request.stream:
             return self._stream_chat_completion(request)
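The SEMAPHORE guard is now applied only when the model-parallel process group is in use, since that single worker group can serve one request at a time. A self-contained illustration of the asyncio.Semaphore(1) + .locked() rejection pattern used above (a standalone sketch, not code from this repo):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)


async def handle_request(name: str) -> str:
    # Same shape as the guard in the diff: reject a second caller instead of queueing it.
    if SEMAPHORE.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        await asyncio.sleep(0.1)  # stand-in for running generation
        return f"{name} done"


async def main() -> None:
    results = await asyncio.gather(
        handle_request("a"), handle_request("b"), return_exceptions=True
    )
    print(results)  # one request completes, the other raises RuntimeError


asyncio.run(main())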
@@ -110,7 +116,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
+        def impl():
            messages = chat_completion_request_to_messages(request)

            tokens = []
@@ -154,10 +160,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                logprobs=logprobs if request.logprobs else None,
            )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        async with SEMAPHORE:
+        def impl():
            messages = chat_completion_request_to_messages(request)

            yield ChatCompletionResponseStreamChunk(
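The key refactor in both completion paths: the request body moves into a local impl() closure, so the same code can run either under the semaphore (process-group case) or directly (in-process case). Reduced to its essentials, the non-streaming shape looks like this (a sketch with stand-in names, not repo code):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)
USE_PROCESS_GROUP = True  # stand-in for config.create_distributed_process_group


async def nonstream_completion(request: str) -> str:
    def impl() -> str:
        # The body is written once; the caller decides whether it runs under the semaphore.
        return f"completed: {request}"

    if USE_PROCESS_GROUP:
        async with SEMAPHORE:
            return impl()
    else:
        return impl()


print(asyncio.run(nonstream_completion("hello")))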
@@ -272,6 +284,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                )
            )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
     async def embeddings(
         self,
         model: str,
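The streaming path cannot simply return impl() from inside the async with block, because the semaphore would be released before the caller ever iterated the generator. Instead the method re-yields each chunk inside the block, holding the semaphore for the whole stream. A standalone sketch of that shape (illustrative names only):

import asyncio
from typing import AsyncGenerator, Generator

SEMAPHORE = asyncio.Semaphore(1)
USE_PROCESS_GROUP = True  # stand-in for config.create_distributed_process_group


async def stream_completion(request: str) -> AsyncGenerator[str, None]:
    def impl() -> Generator[str, None, None]:
        for token in ["hel", "lo ", request]:
            yield token  # stand-in for ChatCompletionResponseStreamChunk

    if USE_PROCESS_GROUP:
        # Hold the semaphore until the stream is fully consumed.
        async with SEMAPHORE:
            for x in impl():
                yield x
    else:
        for x in impl():
            yield x


async def main() -> None:
    async for chunk in stream_completion("world"):
        print(chunk, end="")
    print()


asyncio.run(main())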