Mirror of https://github.com/meta-llama/llama-stack.git
Commit 4bda515990 (parent bb3c26cfc5)

Rename config var: use_elastic_agent → create_distributed_process_group

2 changed files with 6 additions and 6 deletions
In the meta-reference inference config:

@@ -27,7 +27,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
     # (including our testing code) who might be using llama-stack as a library.
-    use_elastic_agent: bool = True
+    create_distributed_process_group: bool = True
 
     @field_validator("model")
     @classmethod
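The renamed flag sits on a Pydantic model alongside a model validator. A minimal sketch of that shape, assuming pydantic v2; InferenceConfig, check_model, and the model descriptor below are illustrative stand-ins, not the actual llama-stack definitions:

from pydantic import BaseModel, field_validator

class InferenceConfig(BaseModel):
    # Illustrative stand-in for MetaReferenceInferenceConfig.
    model: str
    # When False, the distributed process group is assumed to be set up
    # externally (e.g., under `torchrun`), which suits library-style callers.
    create_distributed_process_group: bool = True

    @field_validator("model")
    @classmethod
    def check_model(cls, value: str) -> str:
        if not value.strip():
            raise ValueError("model must be a non-empty descriptor")
        return value

cfg = InferenceConfig(model="Llama3.1-8B-Instruct", create_distributed_process_group=False)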
In the meta-reference inference implementation:

@@ -37,7 +37,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
 
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
@@ -55,7 +55,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         ]
 
     async def shutdown(self) -> None:
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             self.generator.stop()
 
     def completion(
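initialize() and shutdown() bracket the parallel generator's lifetime only when this impl owns the process group. A toy sketch of that conditional lifecycle; FakeGenerator and LifecycleDemo are hypothetical stand-ins for LlamaModelParallelGenerator and the impl class:

import asyncio

class FakeGenerator:
    # Hypothetical stand-in for LlamaModelParallelGenerator.
    def start(self) -> None:
        print("process group up")

    def stop(self) -> None:
        print("process group down")

class LifecycleDemo:
    def __init__(self, create_distributed_process_group: bool) -> None:
        self.create_distributed_process_group = create_distributed_process_group
        self.generator = FakeGenerator()

    async def initialize(self) -> None:
        # Only start the generator when we own the process group.
        if self.create_distributed_process_group:
            self.generator.start()

    async def shutdown(self) -> None:
        if self.create_distributed_process_group:
            self.generator.stop()

async def main() -> None:
    demo = LifecycleDemo(create_distributed_process_group=True)
    await demo.initialize()
    await demo.shutdown()

asyncio.run(main())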
@@ -104,7 +104,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
                 raise RuntimeError("Only one concurrent request is supported")
 
@@ -160,7 +160,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
                 logprobs=logprobs if request.logprobs else None,
             )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 return impl()
         else:
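When the impl owns the process group it serializes requests behind a single-slot semaphore, failing fast rather than queueing. A standalone sketch of that pattern with asyncio; generate() and its body are illustrative, not the real method:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)  # one in-flight request, as in the diff

async def generate(prompt: str) -> str:
    # Reject immediately instead of queueing behind a running request.
    if SEMAPHORE.locked():
        raise RuntimeError("Only one concurrent request is supported")
    async with SEMAPHORE:
        await asyncio.sleep(0.1)  # stand-in for model-parallel generation
        return f"completion for {prompt!r}"

print(asyncio.run(generate("hello")))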
@@ -284,7 +284,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             )
         )
 
-        if self.config.use_elastic_agent:
+        if self.config.create_distributed_process_group:
             async with SEMAPHORE:
                 for x in impl():
                     yield x
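The streaming path holds the same semaphore for the whole life of the generator, so chunks from one request can never interleave with another. A small sketch; impl() here is a hypothetical token source, not the real inner function:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

def impl():
    # Hypothetical synchronous token generator.
    yield from ("Hello", ",", " world")

async def stream():
    # The lock is held until the consumer exhausts the stream.
    async with SEMAPHORE:
        for x in impl():
            yield x

async def main() -> None:
    async for chunk in stream():
        print(chunk)

asyncio.run(main())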