diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py
index a316c733a..4e1161ced 100644
--- a/llama_stack/providers/impls/meta_reference/inference/config.py
+++ b/llama_stack/providers/impls/meta_reference/inference/config.py
@@ -27,7 +27,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
     # (including our testing code) who might be using llama-stack as a library.
-    use_elastic_agent: bool = True
+    create_distributed_process_group: bool = True
 
     @field_validator("model")
     @classmethod
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 55d5de72c..7edc279d0 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -37,7 +37,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
 
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
@@ -55,7 +55,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]
 
     async def shutdown(self) -> None:
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator.stop()
 
     def completion(
@@ -104,7 +104,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
                 raise RuntimeError("Only one concurrent request is supported")
 
@@ -160,7 +160,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 return impl()
         else:
@@ -284,7 +284,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 for x in impl():
                     yield x
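For context, a minimal sketch of the library-mode scenario the config comment describes: a client launched under `torchrun` initializes the distributed process group itself and passes `create_distributed_process_group=False` so the provider does not create its own. The model descriptor and the `nccl` backend below are illustrative assumptions; only the `model` field and the renamed flag are taken from this diff.

```python
# Hedged sketch (not part of this PR): llama-stack used as a library under
# `torchrun`, which exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE so the
# default env:// rendezvous works without extra arguments.
import torch.distributed as dist

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)

# The caller, not the provider, creates the distributed process group.
dist.init_process_group(backend="nccl")  # backend choice is an assumption

config = MetaReferenceInferenceConfig(
    model="Llama3.1-8B-Instruct",  # hypothetical model descriptor
    # The flag renamed in this diff: False tells the provider that the
    # process group already exists (here, set up via torchrun + init above).
    create_distributed_process_group=False,
)
```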