Rename config var

Ashwin Bharambe 2024-10-18 12:46:31 -07:00
parent bb3c26cfc5
commit 4bda515990
2 changed files with 6 additions and 6 deletions


@@ -27,7 +27,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
     # (including our testing code) who might be using llama-stack as a library.
-    use_elastic_agent: bool = True
+    create_distributed_process_group: bool = True
 
     @field_validator("model")
     @classmethod
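Not part of the commit, but for context: a minimal, runnable sketch of the two modes the renamed flag selects. `DemoConfig` is a stand-in for `MetaReferenceInferenceConfig` (the real class has more fields):

from pydantic import BaseModel

class DemoConfig(BaseModel):
    # True (default): llama-stack creates the distributed process group itself.
    # False: the group is set up externally, e.g. by `torchrun`; useful when
    # llama-stack is embedded as a library.
    create_distributed_process_group: bool = True

# A client already running under `torchrun` opts out of internal group setup:
config = DemoConfig(create_distributed_process_group=False)
assert not config.create_distributed_process_group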


@@ -37,7 +37,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
@@ -55,7 +55,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]
 
     async def shutdown(self) -> None:
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator.stop()
 
     def completion(
@@ -104,7 +104,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
                 raise RuntimeError("Only one concurrent request is supported")
@@ -160,7 +160,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 return impl()
         else:
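The two hunks above gate non-streaming generation behind a module-level semaphore whenever the process group is created internally. A self-contained sketch of that guard pattern (`SEMAPHORE` and `impl` are stand-ins mirroring the diff context, not the actual implementation):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def guarded(impl):
    # Reject a second in-flight request outright instead of queueing it.
    if SEMAPHORE.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        return impl()

print(asyncio.run(guarded(lambda: "ok")))  # -> ok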
@@ -284,7 +284,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 for x in impl():
                     yield x
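The streaming path holds the same semaphore for the whole lifetime of the generator, so one request's chunks are fully drained before another request can start. A hedged sketch of that shape (again with stand-in names):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def guarded_stream(impl):
    # Keep the semaphore held until the synchronous generator is exhausted.
    async with SEMAPHORE:
        for x in impl():
            yield x

async def main():
    async for chunk in guarded_stream(lambda: iter(["a", "b"])):
        print(chunk)

asyncio.run(main())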