Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit 81e1957446 ("temp commit"), parent 30f6eb282f
10 changed files with 54 additions and 39 deletions
New empty files (0 lines each): "=0.10.9", "==0.10.9"
@@ -46,8 +46,10 @@ class MetaReferenceInferenceConfig(BaseModel):
         return model
 
     @property
-    def model_parallel_size(self) -> int:
+    def model_parallel_size(self) -> Optional[int]:
         resolved = resolve_model(self.model)
+        if resolved is None:
+            return None
         return resolved.pth_file_count
 
     @classmethod
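For context, the property now degrades gracefully when the configured model is not a known SKU. A minimal self-contained sketch of that behaviour (the SKU table, FakeSku class, and resolve_model below are stand-ins, not the real llama_models API):

from typing import Optional

class FakeSku:
    """Hypothetical stand-in for a resolved llama-models SKU."""
    def __init__(self, pth_file_count: int) -> None:
        self.pth_file_count = pth_file_count

KNOWN_SKUS = {"Llama3.1-8B-Instruct": FakeSku(pth_file_count=1)}

def resolve_model(descriptor: str) -> Optional[FakeSku]:
    # Unknown descriptors return None instead of raising.
    return KNOWN_SKUS.get(descriptor)

def model_parallel_size(model: str) -> Optional[int]:
    resolved = resolve_model(model)
    if resolved is None:
        # Let callers fall back to the llama_model passed in at build time.
        return None
    return resolved.pth_file_count

assert model_parallel_size("Llama3.1-8B-Instruct") == 1
assert model_parallel_size("not-a-sku") is None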
@@ -25,12 +25,12 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -53,16 +53,16 @@ from .config import (
 log = logging.getLogger(__name__)
 
 
-def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model.descriptor()))
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
 
     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
     if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"
 
     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
-        f"Please download model using `llama download --model-id {model.descriptor()}`"
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"Please download model using `llama download --model-id {model_id}`"
     )
     return str(checkpoint_dir)
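A rough, self-contained sketch of the lookup this hunk re-keys on model_id (model_local_dir is replaced here by an explicit base path, so the helper below is illustrative, not the provider's actual code):

from pathlib import Path

def find_checkpoint_dir(base: Path) -> Path:
    # Prefer consolidated.pth / consolidated.00.pth directly under base,
    # otherwise fall back to the "original" subdirectory.
    paths = [base / f"consolidated.{ext}" for ext in ("pth", "00.pth")]
    if not any(p.exists() for p in paths):
        base = base / "original"
    if not base.exists():
        raise FileNotFoundError(f"Could not find checkpoints in: {base}")
    return base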
@@ -80,6 +80,7 @@ class Llama:
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
         model_id: str,
+        llama_model: Model,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -88,14 +89,16 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        model = resolve_model(model_id)
-
-        llama_model = model.core_model_id.value
+        llama_model_id = llama_model.core_model_id.value
 
         if not torch.distributed.is_initialized():
+            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")
 
-        model_parallel_size = config.model_parallel_size
+        model_parallel_size = (
+            config.model_parallel_size
+            if config.model_parallel_size
+            else llama_model.pth_file_count
+        )
 
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
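The effective parallel size now prefers the explicit config value and otherwise uses one rank per checkpoint shard of the resolved llama_model. A small illustrative sketch (LlamaModelStub is a stand-in type, not the llama_models class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class LlamaModelStub:
    # Stand-in for the resolved llama_model; pth_file_count is the shard count.
    pth_file_count: int

def effective_model_parallel_size(configured: Optional[int], llama_model: LlamaModelStub) -> int:
    # Explicit config wins; otherwise use one rank per checkpoint shard.
    return configured if configured else llama_model.pth_file_count

assert effective_model_parallel_size(None, LlamaModelStub(pth_file_count=8)) == 8
assert effective_model_parallel_size(2, LlamaModelStub(pth_file_count=8)) == 2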
@@ -103,6 +106,8 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
 
+        print("torch.cuda.set_device")
+
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
@@ -114,7 +119,7 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model)
+            ckpt_dir = model_checkpoint_dir(model_id)  # true model id
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@@ -190,7 +195,7 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)
 
         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args, llama_model)
+        return Llama(model, tokenizer, model_args, llama_model_id)
 
     def __init__(
         self,
@@ -46,13 +46,16 @@ class MetaReferenceInferenceImpl(
         self.config = config
         self.model = None
 
-    async def initialize(self, model_id) -> None:
+    async def initialize(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")
        if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config, model_id)
+            print("I reach create_distributed_process_group")
+            self.generator = LlamaModelParallelGenerator(
+                self.config, model_id, llama_model
+            )
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config, model_id)
+            self.generator = Llama.build(self.config, model_id, llama_model)
 
         self.model = model_id
@@ -65,26 +68,27 @@ class MetaReferenceInferenceImpl(
             raise RuntimeError(
                 "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
             )
-        inference_model = resolve_model(self.model)
-        requested_model = resolve_model(request.model)
-        if requested_model is None:
+        if request.model is None:
             raise RuntimeError(
                 f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif requested_model.descriptor() != inference_model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
-            )
+        elif request.model != self.model:
+            raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")
 
     async def unregister_model(self, model_id: str) -> None:
         pass
 
     async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
-        llama_model = resolve_model(model.identifier)
+        llama_model = (
+            resolve_model(model.metadata["llama_model"])
+            if "llama_model" in model.metadata
+            else resolve_model(model.identifier)
+        )
         if llama_model is None:
             raise RuntimeError(
-                f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
+                "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
             )
 
         self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
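With this change a model can be registered under any identifier as long as its metadata names the underlying SKU via "llama_model". A hedged sketch of that resolution order (the SKU table and resolve_model below are stand-ins for the llama_models lookup):

from typing import Optional

SKUS = {"Llama3.1-8B-Instruct": "sku-8b"}  # hypothetical SKU table

def resolve_model(name: str) -> Optional[str]:
    return SKUS.get(name)

def pick_llama_model(identifier: str, metadata: dict) -> Optional[str]:
    # Prefer an explicit "llama_model" hint in metadata over the registered identifier.
    if "llama_model" in metadata:
        return resolve_model(metadata["llama_model"])
    return resolve_model(identifier)

# A provider-specific alias resolves because its metadata names the real SKU.
assert pick_llama_model("my-alias", {"llama_model": "Llama3.1-8B-Instruct"}) == "sku-8b"
assert pick_llama_model("unknown-model", {}) is None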
@@ -94,6 +98,7 @@ class MetaReferenceInferenceImpl(
             ],
         )
         model = await self.model_registry_helper.register_model(model)
 
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
@@ -103,7 +108,7 @@ class MetaReferenceInferenceImpl(
             and model.metadata["skip_initialize"]
         ):
             return model
-        await self.initialize(model.identifier)
+        await self.initialize(model.identifier, llama_model)
         return model
 
     async def completion(
@@ -10,8 +10,8 @@ from functools import partial
 from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
@@ -37,8 +37,9 @@ class ModelRunner:
 def init_model_cb(
     config: MetaReferenceInferenceConfig,
     model_id: str,
+    llama_model: Model,
 ):
-    llama = Llama.build(config, model_id)
+    llama = Llama.build(config, model_id, llama_model)
     return ModelRunner(llama)
@@ -57,14 +58,15 @@ class LlamaModelParallelGenerator:
         self,
         config: MetaReferenceInferenceConfig,
         model_id: str,
+        llama_model: Model,
     ):
         self.config = config
         self.model_id = model_id
-        self.model = resolve_model(model_id)
+        self.llama_model = llama_model
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model)
+        checkpoint_dir = model_checkpoint_dir(self.model_id)
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
@@ -78,11 +80,15 @@ class LlamaModelParallelGenerator:
         if self.config.model_parallel_size:
             model_parallel_size = self.config.model_parallel_size
         else:
-            model_parallel_size = resolve_model(self.model).pth_file_count
+            model_parallel_size = self.llama_model.pth_file_count
 
+        print(f"model_parallel_size: {model_parallel_size}")
+
         self.group = ModelParallelProcessGroup(
             model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.model_id),
+            init_model_cb=partial(
+                init_model_cb, self.config, self.model_id, self.llama_model
+            ),
         )
         self.group.start()
         return self
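The extra llama_model argument is threaded into the worker through functools.partial, so the process group can still invoke the callback with no arguments. A minimal sketch of that binding (the callback body and argument values are illustrative):

from functools import partial

def init_model_cb(config: str, model_id: str, llama_model: str) -> str:
    # Stand-in for the real callback that builds the generator inside the worker.
    return f"built {model_id} ({llama_model}) with {config}"

cb = partial(init_model_cb, "meta-reference-config", "my-model-id", "Llama3.1-8B-Instruct")
# The model-parallel process group later calls cb() with no extra arguments.
assert cb() == "built my-model-id (Llama3.1-8B-Instruct) with meta-reference-config"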
@@ -300,7 +300,7 @@ def start_model_parallel_process(
 
     main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
 
-    ctx = multiprocessing.get_context("fork")
+    ctx = multiprocessing.get_context("spawn")
     process = ctx.Process(
         target=launch_dist_group,
         args=(
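Switching the start method from "fork" to "spawn" gives the worker a fresh interpreter, which avoids inheriting CUDA/NCCL state from the parent process. A minimal usage sketch with an explicit context (the worker function and URL are illustrative):

import multiprocessing

def worker(url: str) -> None:
    # Under "spawn" this runs in a fresh interpreter, so no parent CUDA state is inherited.
    print(f"worker connecting to {url}")

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=worker, args=("tcp://127.0.0.1:5555",))
    p.start()
    p.join()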
@@ -69,7 +69,6 @@ def pytest_generate_tests(metafunc):
     else:
         params = MODEL_PARAMS
 
-    # print("params", params)
     metafunc.parametrize(
         "inference_model",
         params,
@@ -83,7 +82,5 @@ def pytest_generate_tests(metafunc):
                 "inference": INFERENCE_FIXTURES,
             },
         ):
-            # print("I reach here")
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
-            print("fixtures", fixtures)
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
@@ -9,8 +9,8 @@ import pytest
 
 # How to run this test:
 #
-# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
-# -m "meta_reference"
+# torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct"
+# ./llama_stack/providers/tests/inference/test_model_registration.py
 
 
 class TestModelRegistration:
@@ -75,7 +75,7 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
|
|
Loading…
Add table
Add a link
Reference in a new issue