Mirror of https://github.com/meta-llama/llama-stack.git
commit 29d0896ec8 (parent 0f78a5fb2d)
address commit

1 changed file with 8 additions and 21 deletions
@@ -38,27 +38,22 @@ class ModelConfig(BaseModel):
     checkpoint_type: str


-class ModelConfigs(BaseModel):
-    Llama3_2_3B_Instruct: ModelConfig
-    Llama_3_8B_Instruct: ModelConfig
-
-
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]


-MODEL_CONFIGS = ModelConfigs(
-    Llama3_2_3B_Instruct=ModelConfig(
+MODEL_CONFIGS: Dict[str, ModelConfig] = {
+    "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3_2",
     ),
-    Llama_3_8B_Instruct=ModelConfig(
+    "Llama-3-8B-Instruct": ModelConfig(
         model_definition=lora_llama3_8b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3",
     ),
-)
+}


 EXPECTED_DATASET_SCHEMA = DatasetSchema(
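Taken together, this hunk swaps a Pydantic ModelConfigs BaseModel for a plain dict keyed by the canonical model ID, so the raw ID string (dots and dashes included) indexes the registry directly. A minimal, self-contained sketch of that pattern, with the ModelConfig fields reduced to strings for illustration (the real fields hold torchtune callables):

# Sketch of the registry pattern this hunk adopts: a dict keyed by the raw
# model ID replaces attribute access on a BaseModel. Field types are
# simplified to strings here; the real config stores callables.
from typing import Dict

from pydantic import BaseModel


class ModelConfig(BaseModel):
    model_definition: str
    tokenizer_type: str
    checkpoint_type: str


MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Llama3.2-3B-Instruct": ModelConfig(
        model_definition="lora_llama3_2_3b",
        tokenizer_type="llama3_tokenizer",
        checkpoint_type="LLAMA3_2",
    ),
}

# The ID works as a key exactly as spelled; no attribute-name munging needed.
print(MODEL_CONFIGS["Llama3.2-3B-Instruct"].checkpoint_type)  # LLAMA3_2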
@@ -85,14 +80,9 @@ BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]


-def _modify_model_id(model_id: str) -> str:
-    return model_id.replace("-", "_").replace(".", "_")
-
-
 def _validate_model_id(model_id: str) -> Model:
     model = resolve_model(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    if model is None or not hasattr(MODEL_CONFIGS, modified_model_id):
+    if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
     return model

@@ -101,8 +91,7 @@ async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -112,8 +101,7 @@ async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
     if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -127,8 +115,7 @@ async def get_checkpointer_model_type(
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
     model = _validate_model_id(model_id)
-    modified_model_id = _modify_model_id(model.core_model_id.value)
-    model_config = getattr(MODEL_CONFIGS, modified_model_id)
+    model_config = MODEL_CONFIGS[model.core_model_id.value]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
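After this change, all three getters share one shape: validate the ID, index the dict with the resolved core model ID, and read one field. A standalone sketch of that shape; ModelConfig and MODEL_CONFIGS are simplified stand-ins, and the real functions go through resolve_model and _validate_model_id rather than the inline check shown here:

# Standalone sketch of the shared getter shape after this diff; simplified
# stand-ins, not the module's real imports.
import asyncio
from typing import Dict

from pydantic import BaseModel


class ModelConfig(BaseModel):
    checkpoint_type: str


MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Llama3.2-3B-Instruct": ModelConfig(checkpoint_type="LLAMA3_2"),
}


async def get_checkpointer_model_type(model_id: str) -> str:
    # Dict membership plays the role of _validate_model_id() in this sketch.
    if model_id not in MODEL_CONFIGS:
        raise ValueError(f"Model {model_id} is not supported.")
    return MODEL_CONFIGS[model_id].checkpoint_type


print(asyncio.run(get_checkpointer_model_type("Llama3.2-3B-Instruct")))  # LLAMA3_2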