Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-29 03:14:19 +00:00)
Add an option to not use elastic agents for meta-reference inference (#269)

commit 33afd34e6f (parent be3c5c034d)
2 changed files with 33 additions and 8 deletions
File 1: the meta-reference inference config (the module defining MetaReferenceInferenceConfig, imported later in this diff as .config)

@@ -17,13 +17,18 @@ from llama_stack.providers.utils.inference import supported_inference_models
 class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

+    # when this is False, we assume that the distributed process group is setup by someone
+    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
+    # (including our testing code) who might be using llama-stack as a library.
+    create_distributed_process_group: bool = True
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
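To illustrate the new flag, here is a minimal sketch, not part of the commit, of how a client embedding llama-stack as a library might construct this config; the module path in the import is an assumption and may differ from the actual package layout.

# Sketch only: the import path below is assumed, not taken from the diff.
from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)

# When the process group is managed outside llama-stack (e.g. the script is
# launched with `torchrun`), opt out of the provider creating its own group.
config = MetaReferenceInferenceConfig(
    model="Llama3.2-3B-Instruct",
    max_seq_len=4096,
    max_batch_size=1,
    create_distributed_process_group=False,
)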
File 2: the meta-reference inference implementation (the module defining MetaReferenceInferenceImpl)

@@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import MetaReferenceInferenceConfig
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,
@@ -36,8 +37,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):

     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        self.generator = LlamaModelParallelGenerator(self.config)
-        self.generator.start()
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config)

     async def register_model(self, model: ModelDef) -> None:
         raise ValueError("Dynamic model registration is not supported")
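The else branch only makes sense when the distributed environment already exists, as the config comment above states. A hedged sketch of what a torchrun-launched caller looks like (file name, backend choice, and invocation are illustrative, not part of the commit):

# run_inference.py -- launched as: torchrun --nproc-per-node 1 run_inference.py
# torchrun exports RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT, so the process
# group is joined here rather than created by the inference provider.
import os

import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend="gloo")  # backend is an illustrative choice

print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} joined the group")
# ... build MetaReferenceInferenceConfig(create_distributed_process_group=False)
# and initialize the provider from here ...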
@@ -51,6 +55,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]

     async def shutdown(self) -> None:
-        self.generator.stop()
+        if self.config.create_distributed_process_group:
+            self.generator.stop()

     def completion(
@@ -99,6 +104,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

-        if SEMAPHORE.locked():
-            raise RuntimeError("Only one concurrent request is supported")
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")
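SEMAPHORE is pre-existing module state in this file, presumably an asyncio.Semaphore(1) guarding the single model-parallel group; after this change the guard is applied only when the provider owns the process group. A standalone sketch of the guard pattern itself (names and structure are illustrative):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)  # at most one in-flight inference request

async def guarded(handler):
    # Fail fast instead of queueing a second concurrent request.
    if SEMAPHORE.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        return await handler()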
@@ -110,7 +116,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)

             tokens = []
@@ -154,10 +160,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)

             yield ChatCompletionResponseStreamChunk(
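This hunk and the previous one apply one refactor: the body that used to run under async with SEMAPHORE: becomes a local impl() closure, and the caller takes the semaphore only when the provider manages its own process group. A self-contained sketch of the non-streaming shape (class and attribute names are illustrative, not the provider's):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

class Example:
    def __init__(self, serialize_requests: bool):
        # Stands in for config.create_distributed_process_group.
        self.serialize_requests = serialize_requests

    async def nonstream(self, request: str) -> str:
        def impl() -> str:
            # Synchronous generation work happens here.
            return f"response to {request}"

        if self.serialize_requests:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()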
@@ -272,6 +284,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                     )
                 )

+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
     async def embeddings(
         self,
         model: str,
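The streaming variant follows the same shape, except impl() is a synchronous generator whose chunks are re-yielded from the async generator, so the semaphore (when used) is held for the whole stream. A minimal sketch under the same illustrative names as above:

import asyncio
from typing import AsyncGenerator

SEMAPHORE = asyncio.Semaphore(1)

async def stream(serialize_requests: bool, request: str) -> AsyncGenerator[str, None]:
    def impl():
        # Synchronous generator producing stream chunks.
        for i in range(3):
            yield f"chunk {i} for {request}"

    if serialize_requests:
        async with SEMAPHORE:
            # Hold the semaphore for the duration of the stream.
            for x in impl():
                yield x
    else:
        for x in impl():
            yield x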