mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
address comments
This commit is contained in:
parent
d7d19dc0e5
commit
3378c100f6
1 changed file with 54 additions and 20 deletions
|
@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List
|
||||||
import torch
|
import torch
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.common.type_system import * # noqa
|
from llama_stack.apis.common.type_system import * # noqa
|
||||||
|
from llama_models.datatypes import Model
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
|
||||||
|
@ -31,18 +32,29 @@ class ColumnName(Enum):
|
||||||
text = "text"
|
text = "text"
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
|
class ModelConfig(BaseModel):
|
||||||
"Llama3.2-3B-Instruct": {
|
model_definition: Any
|
||||||
"model_definition": lora_llama3_2_3b,
|
tokenizer_type: Any
|
||||||
"tokenizer_type": llama3_tokenizer,
|
checkpoint_type: str
|
||||||
"checkpoint_type": "LLAMA3_2",
|
|
||||||
},
|
|
||||||
"Llama-3-8B-Instruct": {
|
class ModelConfigs(BaseModel):
|
||||||
"model_definition": lora_llama3_8b,
|
Llama3_2_3B_Instruct: ModelConfig
|
||||||
"tokenizer_type": llama3_tokenizer,
|
Llama_3_8B_Instruct: ModelConfig
|
||||||
"checkpoint_type": "LLAMA3",
|
|
||||||
},
|
|
||||||
}
|
MODEL_CONFIGS = ModelConfigs(
|
||||||
|
Llama3_2_3B_Instruct=ModelConfig(
|
||||||
|
model_definition=lora_llama3_2_3b,
|
||||||
|
tokenizer_type=llama3_tokenizer,
|
||||||
|
checkpoint_type="LLAMA3_2",
|
||||||
|
),
|
||||||
|
Llama_3_8B_Instruct=ModelConfig(
|
||||||
|
model_definition=lora_llama3_8b,
|
||||||
|
tokenizer_type=llama3_tokenizer,
|
||||||
|
checkpoint_type="LLAMA3",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
|
EXPECTED_DATASET_SCHEMA: Dict[str, List[Dict[str, ParamType]]] = {
|
||||||
"alpaca": [
|
"alpaca": [
|
||||||
|
@ -68,20 +80,38 @@ BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||||
|
|
||||||
|
|
||||||
|
def _modify_model_id(model_id: str) -> str:
|
||||||
|
return model_id.replace("-", "_").replace(".", "_")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_id(model_id: str) -> Model:
    """Resolve ``model_id`` and check that ``MODEL_CONFIGS`` has an entry for it.

    Args:
        model_id: Identifier passed to ``resolve_model``.

    Returns:
        The object returned by ``resolve_model`` for ``model_id``.

    Raises:
        ValueError: If ``resolve_model`` returns ``None`` or if
            ``MODEL_CONFIGS`` has no attribute matching the normalized
            core model id.
    """
    model = resolve_model(model_id)
    # Guard against None *before* reading model.core_model_id; otherwise an
    # unknown id surfaces as AttributeError instead of the ValueError below.
    if model is None:
        raise ValueError(f"Model {model_id} is not supported.")

    modified_model_id = _modify_model_id(model.core_model_id.value)
    if not hasattr(MODEL_CONFIGS, modified_model_id):
        raise ValueError(f"Model {model_id} is not supported.")
    return model
|
||||||
|
|
||||||
|
|
||||||
async def get_model_definition(
|
async def get_model_definition(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
) -> BuildLoraModelCallable:
|
) -> BuildLoraModelCallable:
|
||||||
model = resolve_model(model_id)
|
model = _validate_model_id(model_id)
|
||||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
modified_model_id = _modify_model_id(model.core_model_id.value)
|
||||||
raise ValueError(f"Model {model_id} is not supported.")
|
model_config = getattr(MODEL_CONFIGS, modified_model_id)
|
||||||
return MODEL_CONFIGS[model.core_model_id.value]["model_definition"]
|
if not hasattr(model_config, "model_definition"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||||
|
return model_config.model_definition
|
||||||
|
|
||||||
|
|
||||||
async def get_tokenizer_type(
|
async def get_tokenizer_type(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
) -> BuildTokenizerCallable:
|
) -> BuildTokenizerCallable:
|
||||||
model = resolve_model(model_id)
|
model = _validate_model_id(model_id)
|
||||||
return MODEL_CONFIGS[model.core_model_id.value]["tokenizer_type"]
|
modified_model_id = _modify_model_id(model.core_model_id.value)
|
||||||
|
model_config = getattr(MODEL_CONFIGS, modified_model_id)
|
||||||
|
if not hasattr(model_config, "tokenizer_type"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||||
|
return model_config.tokenizer_type
|
||||||
|
|
||||||
|
|
||||||
async def get_checkpointer_model_type(
|
async def get_checkpointer_model_type(
|
||||||
|
@ -91,8 +121,12 @@ async def get_checkpointer_model_type(
|
||||||
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||||
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||||
"""
|
"""
|
||||||
model = resolve_model(model_id)
|
model = _validate_model_id(model_id)
|
||||||
return MODEL_CONFIGS[model.core_model_id.value]["checkpoint_type"]
|
modified_model_id = _modify_model_id(model.core_model_id.value)
|
||||||
|
model_config = getattr(MODEL_CONFIGS, modified_model_id)
|
||||||
|
if not hasattr(model_config, "checkpoint_type"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||||
|
return model_config.checkpoint_type
|
||||||
|
|
||||||
|
|
||||||
async def validate_input_dataset_schema(
|
async def validate_input_dataset_schema(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue