commit 7ab807ad76 (parent 9e5b7d5c9e)
Author: Botao Chen
Date: 2024-12-18 15:58:51 -08:00

3 changed files with 17 additions and 2 deletions


```diff
@@ -31,6 +31,7 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
+from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -113,7 +114,13 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model_id)
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
```
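The fallback keeps working for model ids that `resolve_model` does not recognize (for example a locally fine-tuned checkpoint registered under a custom id), while native Llama models map to the directory named after their `descriptor()`. Below is a minimal sketch of the pattern; `model_checkpoint_dir` here is a simplified, hypothetical stand-in for the provider helper (assumed to resolve under `~/.llama/checkpoints`), not the real implementation.

```python
# Sketch of the resolve-or-fallback logic from this commit.
import os

from llama_models.sku_list import resolve_model


def model_checkpoint_dir(model_name: str) -> str:
    # assumption for illustration: checkpoints live under ~/.llama/checkpoints/<model_name>
    return os.path.expanduser(f"~/.llama/checkpoints/{model_name}")


def checkpoint_dir_for(model_id: str) -> str:
    resolved = resolve_model(model_id)
    if resolved is None:
        # not a native llama model (e.g. a fine-tuned checkpoint with a custom id):
        # fall back to a directory named after the raw model id
        return model_checkpoint_dir(model_id)
    # native llama model: use the canonical descriptor, e.g. "Llama3.1-8B-Instruct"
    return model_checkpoint_dir(resolved.descriptor())
```

With this in place, a recognized id such as `"Llama3.1-8B-Instruct"` points at the descriptor-named directory, while an unrecognized id falls back to a directory named after the id itself.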


```diff
@@ -107,6 +107,7 @@ class MetaReferenceInferenceImpl(
         pass
 
     async def register_model(self, model: Model) -> Model:
+        print("model metadata", model.metadata["llama_model"])
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
```


```diff
@@ -12,6 +12,7 @@ from typing import Any, Generator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
@@ -66,7 +67,13 @@ class LlamaModelParallelGenerator:
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model_id)
+        resolved_model = resolve_model(model_id)
+        if resolved_model is None:
+            # if the model is not a native llama model, get the default checkpoint_dir based on model id
+            checkpoint_dir = model_checkpoint_dir(model_id)
+        else:
+            # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
```
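The generator applies the same resolution, then eagerly loads the tokenizer from the resolved checkpoint directory so the agent loop can count tokens before the model process starts. A small sketch of that tail end, assuming the usual layout where `tokenizer.model` sits next to the `*.pth` shards:

```python
# Sketch: build a ChatFormat from the tokenizer stored alongside the checkpoint.
import os

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer


def load_formatter(checkpoint_dir: str) -> ChatFormat:
    # assumption: tokenizer.model lives directly inside the checkpoint directory
    tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
    if not os.path.exists(tokenizer_path):
        raise FileNotFoundError(f"no tokenizer.model found in {checkpoint_dir}")
    return ChatFormat(Tokenizer(tokenizer_path))
```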