This commit is contained in:
Botao Chen 2024-12-17 13:38:19 -08:00
parent 415b8f2dbd
commit 48482ff9c3
9 changed files with 18 additions and 57 deletions

View file

@@ -7,7 +7,6 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator
@@ -16,9 +15,11 @@ from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
model: Optional[str] = (
None # this is a placeholder to indicate inference model id, not actually being used
)
# this is a placeholder to indicate the inference model id
# the actual inference model id is determined by the model id in the request
# Note: you need to register the model before using it for inference
# models in the resource list in the run.yaml config will be registered automatically
model: Optional[str] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
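For illustration, a minimal sketch of constructing the updated config (the import path is an assumption based on the meta-reference provider layout; field names and defaults come from the diff above):

# Sketch only: the import path below is assumed, not taken from this commit.
from llama_stack.providers.inline.inference.meta_reference.config import (
    MetaReferenceInferenceConfig,
)

# `model` stays None; the model actually served is the one registered and then
# named in the inference request.
config = MetaReferenceInferenceConfig(
    model=None,
    torch_seed=42,       # optional, for reproducible generation
    max_seq_len=4096,
    max_batch_size=1,
)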
@@ -45,13 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel):
)
return model
@property
def model_parallel_size(self) -> Optional[int]:
resolved = resolve_model(self.model)
if resolved is None:
return None
return resolved.pth_file_count
@classmethod
def sample_run_config(
cls,

View file

@@ -62,7 +62,8 @@ def model_checkpoint_dir(model_id) -> str:
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"Please download model using `llama download --model-id {model_id}`"
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)
@@ -91,14 +92,9 @@ class Llama:
"""
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
print("I reach torch.distributed.init_process_group")
torch.distributed.init_process_group("nccl")
model_parallel_size = (
config.model_parallel_size
if config.model_parallel_size
else llama_model.pth_file_count
)
model_parallel_size = llama_model.pth_file_count
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
@@ -106,8 +102,6 @@ class Llama:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
print("torch.cuda.set_device")
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
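A rough sketch of the change above: the model parallel degree is no longer read from the config but taken from the resolved model descriptor, whose pth_file_count matches the number of checkpoint shards (the model id below is only an example):

from llama_models.sku_list import resolve_model

llama_model = resolve_model("Llama3.1-8B-Instruct")  # example id, assumed to resolve
if llama_model is not None:
    # number of consolidated .pth shards, e.g. typically 1 for 8B and 8 for 70B
    model_parallel_size = llama_model.pth_file_count
    print(f"model_parallel_size: {model_parallel_size}")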

View file

@@ -11,7 +11,7 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import Model as LlamaStackModel
from llama_stack.apis.models import Model
from llama_models.llama3.api.datatypes import * # noqa: F403
@@ -49,7 +49,6 @@ class MetaReferenceInferenceImpl(
async def initialize(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
print("I reach create_distributed_process_group")
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
@@ -66,19 +65,17 @@ class MetaReferenceInferenceImpl(
def check_model(self, request) -> None:
if self.model is None:
raise RuntimeError(
"Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
)
if request.model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model:
raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
)
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
async def register_model(self, model: Model) -> Model:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
@@ -102,11 +99,7 @@ class MetaReferenceInferenceImpl(
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
if (
model.metadata
and "skip_initialize" in model.metadata
and model.metadata["skip_initialize"]
):
if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
return model
await self.initialize(model.identifier, llama_model)
return model
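A rough illustration of the metadata keys that the reworked register_model reads; a plain dict stands in for Model.metadata and the values are hypothetical:

# Hypothetical metadata attached to a model registration request.
metadata = {
    "llama_model": "Llama3.1-8B-Instruct",  # which native Llama SKU backs this model
    "skip_initialize": True,                # register now, load weights later
}

if "skip_initialize" in metadata and metadata["skip_initialize"]:
    print("registered without calling initialize()")
else:
    print("registered; initialize() loads the checkpoint immediately")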

View file

@@ -77,12 +77,7 @@ class LlamaModelParallelGenerator:
self.__exit__(None, None, None)
def __enter__(self):
if self.config.model_parallel_size:
model_parallel_size = self.config.model_parallel_size
else:
model_parallel_size = self.llama_model.pth_file_count
print(f"model_parallel_size: {model_parallel_size}")
model_parallel_size = self.llama_model.pth_file_count
self.group = ModelParallelProcessGroup(
model_parallel_size,