temp commit

This commit is contained in:
Botao Chen 2024-12-12 21:44:03 -08:00
parent 8efe33646d
commit de44af1501
9 changed files with 153 additions and 53 deletions

View file

@ -79,6 +79,7 @@ class Llama:
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -87,10 +88,13 @@ class Llama:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
model = await self.model_store.get_model(config.model)
base_model = model.metadata["base_model"] or self.model_id
self.model = resolve_model(base_model)
model = resolve_model(config.model)
if config.model:
model = resolve_model(config.model)
elif request:
model = resolve_model(request.model)
else:
raise RuntimeError("you need to provide a model for inference")
llama_model = model.core_model_id.value
if not torch.distributed.is_initialized():