Botao Chen 2024-12-17 13:38:19 -08:00
parent 415b8f2dbd
commit 48482ff9c3
9 changed files with 18 additions and 57 deletions

@@ -32,7 +32,6 @@ def get_impl_api(p: Any) -> Api:
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
     api = get_impl_api(p)
-    print("registering object with provider", api)
     assert obj.provider_id != "remote", "Remote provider should not be registered"
@@ -170,7 +169,6 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-
         # Get existing objects from registry
         existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
@@ -183,12 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
         p = self.impls_by_provider_id[obj.provider_id]
-        if obj is None:
-            print("obj is None")
         registered_obj = await register_object_with_provider(obj, p)
-        if registered_obj is None:
-            print("registered_obj is None")
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
@@ -218,7 +211,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> Model:
-        print("register_model", model_id)
         if provider_model_id is None:
             provider_model_id = model_id
         if provider_id is None:
@@ -244,11 +236,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata=metadata,
             model_type=model_type,
         )
-        if model is None:
-            print("model is None!!!")
-        print("before registered_model")
         registered_model = await self.register_object(model)
-        print("after registered_model")
         return registered_model

     async def unregister_model(self, model_id: str) -> None:
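
For orientation, here is a compact sketch of the registration path the hunks above clean up. It is an illustration only, not code from the commit: the `Model` construction mirrors `ModelsRoutingTable.register_model`, the provider hand-off and registry write mirror `CommonRoutingTableImpl.register_object`, and the field names on `Model` are assumptions about the llama_stack resource types; existing-object checks and defaults are elided.

```python
from llama_stack.apis.models import Model, ModelType  # import as used elsewhere in this commit


async def register_model_sketch(table, model_id: str, provider_id: str) -> Model:
    # Build the resource roughly the way ModelsRoutingTable.register_model does,
    # falling back to the model id for the provider resource id.
    model = Model(
        identifier=model_id,
        provider_resource_id=model_id,
        provider_id=provider_id,
        metadata={},
        model_type=ModelType.llm,
    )
    # Mirror CommonRoutingTableImpl.register_object: hand the object to the
    # owning provider, then persist whatever the provider returns.
    p = table.impls_by_provider_id[model.provider_id]
    registered = await register_object_with_provider(model, p)  # defined in the hunk above
    await table.dist_registry.register(registered)
    return registered
```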

@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional

 from llama_models.datatypes import *  # noqa: F403
-from llama_models.sku_list import resolve_model
 from llama_stack.apis.inference import *  # noqa: F401, F403

 from pydantic import BaseModel, field_validator
@@ -16,9 +15,11 @@ from llama_stack.providers.utils.inference import supported_inference_models
 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = (
-        None  # this is a placeholder to indicate inference model id, not actually being used
-    )
+    # this is a placeholder to indicate the inference model id
+    # the actual inference model id is determined by the model id in the request
+    # Note: you need to register the model before using it for inference
+    # models in the resource list in the run.yaml config will be registered automatically
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
@@ -45,13 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel):
         )
         return model

-    @property
-    def model_parallel_size(self) -> Optional[int]:
-        resolved = resolve_model(self.model)
-        if resolved is None:
-            return None
-        return resolved.pth_file_count
-
     @classmethod
     def sample_run_config(
         cls,
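
The removed `model_parallel_size` property resolved the configured model and returned its `pth_file_count`. After this commit the same value is read from the `llama_model` object at load time (see the `Llama` and `LlamaModelParallelGenerator` hunks below). A minimal sketch of that lookup, using an assumed model descriptor for illustration:

```python
from llama_models.sku_list import resolve_model

# "Llama3.1-8B-Instruct" is an assumed example descriptor, not taken from the commit.
llama_model = resolve_model("Llama3.1-8B-Instruct")
if llama_model is not None:
    # pth_file_count is the number of checkpoint shards; it now directly
    # determines the model-parallel world size.
    model_parallel_size = llama_model.pth_file_count
```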

@@ -62,7 +62,8 @@ def model_checkpoint_dir(model_id) -> str:
     assert checkpoint_dir.exists(), (
         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"Please download model using `llama download --model-id {model_id}`"
+        f"If you are trying to use a native llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
     )
     return str(checkpoint_dir)
@@ -91,14 +92,9 @@ class Llama:
         """
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = (
-            config.model_parallel_size
-            if config.model_parallel_size
-            else llama_model.pth_file_count
-        )
+        model_parallel_size = llama_model.pth_file_count

         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
@@ -106,8 +102,6 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))

         torch.cuda.set_device(local_rank)
-        print("torch.cuda.set_device")

         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
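
Taken together, the hunks above leave the following initialization order in place once the prints and the config-driven parallel size are gone. This is a condensed sketch of the surrounding `Llama` setup code, not the full loader; the fairscale imports are assumed to match the reference implementation.

```python
import os

import torch
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)


def setup_distributed(llama_model, torch_seed=None):
    # One NCCL process group per rank, as in the code above.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")

    # The parallel size now comes straight from the checkpoint shard count.
    model_parallel_size = llama_model.pth_file_count
    if not model_parallel_is_initialized():
        initialize_model_parallel(model_parallel_size)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    if torch_seed is not None:
        torch.manual_seed(torch_seed)
```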

@@ -11,7 +11,7 @@ from typing import AsyncGenerator, List
 from llama_models.sku_list import resolve_model

-from llama_stack.apis.models import Model as LlamaStackModel
+from llama_stack.apis.models import Model

 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -49,7 +49,6 @@ class MetaReferenceInferenceImpl(
     async def initialize(self, model_id, llama_model) -> None:
         log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
-            print("I reach create_distributed_process_group")
             self.generator = LlamaModelParallelGenerator(
                 self.config, model_id, llama_model
             )
@@ -66,19 +65,17 @@ class MetaReferenceInferenceImpl(
     def check_model(self, request) -> None:
         if self.model is None:
             raise RuntimeError(
-                "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
-            )
-        if request.model is None:
-            raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                "No available model yet, please register your requested model or add your model in the resources first"
             )
         elif request.model != self.model:
-            raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
+            )

     async def unregister_model(self, model_id: str) -> None:
         pass

-    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
+    async def register_model(self, model: Model) -> Model:
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
@@ -102,11 +99,7 @@ class MetaReferenceInferenceImpl(
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
-        if (
-            model.metadata
-            and "skip_initialize" in model.metadata
-            and model.metadata["skip_initialize"]
-        ):
+        if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
             return model
         await self.initialize(model.identifier, llama_model)
         return model
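
The reworked `check_model` reduces to two cases: nothing loaded yet, or a request naming a different model than the one that was loaded. A tiny standalone restatement of that logic; the request type here is a stand-in, not the real request class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    # Stand-in for the real inference request object.
    model: str


def check_model(loaded_model: Optional[str], request: FakeRequest) -> None:
    if loaded_model is None:
        raise RuntimeError(
            "No available model yet, please register your requested model "
            "or add your model in the resources first"
        )
    elif request.model != loaded_model:
        raise RuntimeError(
            f"Model mismatch: request model: {request.model} != loaded model: {loaded_model}"
        )
```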

@@ -77,12 +77,7 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
-        if self.config.model_parallel_size:
-            model_parallel_size = self.config.model_parallel_size
-        else:
-            model_parallel_size = self.llama_model.pth_file_count
-
-        print(f"model_parallel_size: {model_parallel_size}")
+        model_parallel_size = self.llama_model.pth_file_count

         self.group = ModelParallelProcessGroup(
             model_parallel_size,

@@ -27,7 +27,8 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
+            m.model_family
+            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
             or is_supported_safety_model(m)
         )
     ]
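
With the change above, Llama 3.3 family models now pass the supported-models filter. A small sketch of the membership test in isolation (the `ModelFamily` import path is assumed from the wildcard imports used elsewhere in this commit; the `is_supported_safety_model` branch is left out):

```python
from llama_models.datatypes import ModelFamily

SUPPORTED_FAMILIES = {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}


def is_supported_family(model) -> bool:
    # Mirrors the updated condition above, minus the safety-model branch.
    return model.model_family in SUPPORTED_FAMILIES
```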

@@ -32,10 +32,6 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models: []
-# - metadata: {}
-#   model_id: ${env.POST_TRAINING_MODEL}
-#   provider_id: meta-reference-inference
-#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets: