refine
commit 48482ff9c3 (parent 415b8f2dbd) in https://github.com/meta-llama/llama-stack.git
9 changed files with 18 additions and 57 deletions
@@ -32,7 +32,6 @@ def get_impl_api(p: Any) -> Api:
 
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
     api = get_impl_api(p)
-    print("registering object with provider", api)
 
     assert obj.provider_id != "remote", "Remote provider should not be registered"
 
@@ -170,7 +169,6 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
-
         # Get existing objects from registry
         existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
 
@@ -183,12 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
         p = self.impls_by_provider_id[obj.provider_id]
 
-        if obj is None:
-            print("obj is None")
-
         registered_obj = await register_object_with_provider(obj, p)
-        if registered_obj is None:
-            print("registered_obj is None")
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
@@ -218,7 +211,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> Model:
-        print("register_model", model_id)
         if provider_model_id is None:
             provider_model_id = model_id
         if provider_id is None:
@@ -244,11 +236,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata=metadata,
             model_type=model_type,
         )
-        if model is None:
-            print("model is None!!!")
-        print("before registered_model")
         registered_model = await self.register_object(model)
-        print("after registered_model")
         return registered_model
 
     async def unregister_model(self, model_id: str) -> None:
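For context, a minimal usage sketch of the cleaned-up registration path (not part of the diff): the keyword arguments mirror the register_model signature shown above, while the table instance, identifier, and metadata values are hypothetical examples.

# Hypothetical sketch: registering a model through an already-constructed
# ModelsRoutingTable. Argument names follow the register_model signature in
# the hunks above; the identifier and metadata values are example placeholders.
async def register_example(models_table) -> str:
    registered = await models_table.register_model(
        model_id="Llama3.1-8B-Instruct",            # client-facing identifier (example)
        provider_id="meta-reference-inference",      # provider that serves it (example)
        provider_model_id=None,                      # falls back to model_id when None
        metadata={"llama_model": "Llama3.1-8B-Instruct"},  # example metadata
    )
    # register_model returns the object produced by register_object
    return registered.identifier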
@@ -7,7 +7,6 @@
 from typing import Any, Dict, Optional
 
 from llama_models.datatypes import *  # noqa: F403
-from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F401, F403
 from pydantic import BaseModel, field_validator
@@ -16,9 +15,11 @@ from llama_stack.providers.utils.inference import supported_inference_models
 
 
 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = (
-        None  # this is a placeholder to indicate inference model id, not actually being used
-    )
+    # this is a placeholder to indicate the inference model id
+    # the actual inference model id is determined by the model id in the request
+    # Note: you need to register the model before using it for inference
+    # models in the resource list in the run.yaml config will be registered automatically
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
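A small sketch of what the new placeholder semantics imply for provider configuration (not part of the diff; the import path is assumed and the field values are examples): the model field can simply be left unset, since the served model now comes from registration.

# Sketch only. Assumption: MetaReferenceInferenceConfig is importable from the
# meta-reference provider's config module (path may differ by version).
from llama_stack.providers.inline.inference.meta_reference.config import (
    MetaReferenceInferenceConfig,
)

config = MetaReferenceInferenceConfig(
    # `model` is deliberately omitted: it is only a placeholder now, and the
    # model actually served is the one registered via the API or run.yaml.
    torch_seed=42,        # example value; optional
    max_seq_len=4096,     # default shown in the hunk above
    max_batch_size=1,     # default shown in the hunk above
)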
@@ -45,13 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel):
             )
         return model
 
-    @property
-    def model_parallel_size(self) -> Optional[int]:
-        resolved = resolve_model(self.model)
-        if resolved is None:
-            return None
-        return resolved.pth_file_count
-
     @classmethod
     def sample_run_config(
         cls,
@@ -62,7 +62,8 @@ def model_checkpoint_dir(model_id) -> str:
 
     assert checkpoint_dir.exists(), (
         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"Please download model using `llama download --model-id {model_id}`"
+        f"If you are trying to use the native llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
     )
     return str(checkpoint_dir)
 
@@ -91,14 +92,9 @@ class Llama:
         """
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")
 
-        model_parallel_size = (
-            config.model_parallel_size
-            if config.model_parallel_size
-            else llama_model.pth_file_count
-        )
+        model_parallel_size = llama_model.pth_file_count
 
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
@@ -106,8 +102,6 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
 
-        print("torch.cuda.set_device")
-
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
@@ -11,7 +11,7 @@ from typing import AsyncGenerator, List
 
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.models import Model as LlamaStackModel
+from llama_stack.apis.models import Model
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 
@@ -49,7 +49,6 @@ class MetaReferenceInferenceImpl(
     async def initialize(self, model_id, llama_model) -> None:
         log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
-            print("I reach create_distributed_process_group")
             self.generator = LlamaModelParallelGenerator(
                 self.config, model_id, llama_model
             )
@@ -66,19 +65,17 @@ class MetaReferenceInferenceImpl(
     def check_model(self, request) -> None:
-        if self.model is None:
-            raise RuntimeError(
-                "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
-            )
         if request.model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                "No available model yet, please register your requested model or add your model to the resources first"
             )
         elif request.model != self.model:
-            raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
+            )
 
     async def unregister_model(self, model_id: str) -> None:
         pass
 
-    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
+    async def register_model(self, model: Model) -> Model:
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
@@ -102,11 +99,7 @@ class MetaReferenceInferenceImpl(
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
 
-        if (
-            model.metadata
-            and "skip_initialize" in model.metadata
-            and model.metadata["skip_initialize"]
-        ):
+        if "skip_initialize" in model.metadata and model.metadata["skip_initialize"]:
             return model
         await self.initialize(model.identifier, llama_model)
         return model
@@ -77,12 +77,7 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
-        if self.config.model_parallel_size:
-            model_parallel_size = self.config.model_parallel_size
-        else:
-            model_parallel_size = self.llama_model.pth_file_count
-
-        print(f"model_parallel_size: {model_parallel_size}")
+        model_parallel_size = self.llama_model.pth_file_count
 
         self.group = ModelParallelProcessGroup(
             model_parallel_size,
@@ -27,7 +27,8 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
+            m.model_family
+            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
             or is_supported_safety_model(m)
         )
     ]
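To make the effect of the widened allow-list concrete, a hedged sketch (import locations assumed from the imports visible elsewhere in this diff) that lists which registered models would now pass the family check, including Llama 3.3:

# Sketch: enumerate registered models whose family is accepted by the updated
# check above. Import paths are assumptions, not taken from this diff.
from llama_models.datatypes import ModelFamily
from llama_models.sku_list import all_registered_models

accepted = {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}

for m in all_registered_models():
    if m.model_family in accepted:
        print(m.core_model_id.value)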
@@ -32,10 +32,6 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models: []
-# - metadata: {}
-#   model_id: ${env.POST_TRAINING_MODEL}
-#   provider_id: meta-reference-inference
-#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets:
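For reference, a hedged example of what a populated entry under models: would look like in this template, following the shape of the commented lines removed above (the model identifier is an example, not taken from the diff):

models:
- metadata: {}
  model_id: Llama3.1-8B-Instruct   # example identifier; use the model you registered
  provider_id: meta-reference-inference
  provider_model_id: null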