Rename config var

Ashwin Bharambe 2024-10-18 12:46:31 -07:00
parent bb3c26cfc5
commit 4bda515990
2 changed files with 6 additions and 6 deletions

@@ -27,7 +27,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
     # (including our testing code) who might be using llama-stack as a library.
-    use_elastic_agent: bool = True
+    create_distributed_process_group: bool = True
 
     @field_validator("model")
     @classmethod
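
For orientation, a minimal sketch of the config class as it reads after the rename. The comment, the flag, and the two decorators come from the hunk above; the `model` field's type and the validator body are assumptions, included only so the sketch runs:

```python
from pydantic import BaseModel, field_validator


class MetaReferenceInferenceConfig(BaseModel):
    # assumption: the field checked by the validator below; its real type may differ
    model: str

    # when this is False, we assume that the distributed process group is setup by someone
    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
    # (including our testing code) who might be using llama-stack as a library.
    create_distributed_process_group: bool = True

    @field_validator("model")
    @classmethod
    def validate_model(cls, value: str) -> str:
        # hypothetical body; the real validator is not shown in the hunk
        return value
```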

@@ -37,7 +37,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
@@ -55,7 +55,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]
 
     async def shutdown(self) -> None:
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator.stop()
 
     def completion(
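
Read with the config comment, these lifecycle hunks show what the new name buys: the flag says whether this provider must create (and later tear down) the model-parallel process group itself. A client already running under `torchrun` would turn it off. A hypothetical usage sketch, where the import paths and the `MetaReferenceInferenceImpl(config)` constructor signature are assumptions:

```python
import asyncio

# import paths are assumptions based on llama-stack's layout at the time
from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.impls.meta_reference.inference.inference import (
    MetaReferenceInferenceImpl,
)


async def main() -> None:
    config = MetaReferenceInferenceConfig(
        model="Llama3.1-8B-Instruct",  # example value, not from the diff
        # torchrun already initialized the process group, so skip creating one;
        # initialize() then takes the `else:` branch and starts no generator
        create_distributed_process_group=False,
    )
    impl = MetaReferenceInferenceImpl(config)  # assumed constructor signature
    await impl.initialize()


asyncio.run(main())
```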
@@ -104,7 +104,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
                 raise RuntimeError("Only one concurrent request is supported")
@@ -160,7 +160,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 return impl()
         else:
@@ -284,7 +284,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             )
         )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 for x in impl():
                     yield x
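
The remaining hunks all gate one concurrency pattern: when this provider owns the process group, a module-level `SEMAPHORE` serializes requests, failing fast rather than queueing; otherwise requests pass straight through. A self-contained sketch of that pattern (the `gated` helper and its arguments are illustrative, not the real API):

```python
import asyncio

SEMAPHORE = asyncio.Semaphore(1)  # one in-flight request, as in the diff


async def gated(create_distributed_process_group: bool, impl):
    """Run `impl` under the semaphore only when we own the process group."""
    if create_distributed_process_group:
        if SEMAPHORE.locked():
            # fail fast instead of queueing, mirroring the hunk at -104
            raise RuntimeError("Only one concurrent request is supported")
        async with SEMAPHORE:
            return impl()
    else:
        return impl()


async def main() -> None:
    print(await gated(True, lambda: "response"))


asyncio.run(main())
```

Note that in the streaming hunk the semaphore is held across the whole `for x in impl(): yield x` loop, so a second request is rejected until the first stream finishes.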