From 7ab807ad76b3650ff37c9917ebb965b5a7aee5be Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 18 Dec 2024 15:58:51 -0800
Subject: [PATCH] refine: resolve default checkpoint_dir from core_model_id for native llama models

---
 .../inline/inference/meta_reference/generation.py     | 9 ++++++++-
 .../inline/inference/meta_reference/inference.py      | 1 +
 .../inline/inference/meta_reference/model_parallel.py | 9 ++++++++-
 3 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 28203e92e..c89183cb7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -31,6 +31,7 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
+from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -113,7 +114,13 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model_id)
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index d89bb21f7..1ea6d3e6a 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -107,6 +107,7 @@ class MetaReferenceInferenceImpl(
         pass
 
     async def register_model(self, model: Model) -> Model:
+        print("model metadata", model.metadata.get("llama_model"))
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index e7a05a1e1..cb422b9b6 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -12,6 +12,7 @@ from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
@@ -66,7 +67,13 @@ class LlamaModelParallelGenerator:
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model_id)
+        resolved_model = resolve_model(model_id)
+        if resolved_model is None:
+            # if the model is not a native llama model, get the default checkpoint_dir based on model id
+            checkpoint_dir = model_checkpoint_dir(model_id)
+        else:
+            # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
 
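
All three hunks apply the same checkpoint-dir selection rule. The sketch below restates it as a standalone function, assuming only what the patch itself shows: resolve_model (from llama_models.sku_list) returns None for model ids it does not recognize and an object with a descriptor() method otherwise. model_checkpoint_dir here is a hypothetical stand-in for the provider helper of the same name; its path layout is an assumption, not the real implementation.

import os
from typing import Optional

from llama_models.sku_list import resolve_model


def model_checkpoint_dir(model: str) -> str:
    # Hypothetical stand-in: the real provider helper maps a model name to its
    # local checkpoint directory; the path layout here is only illustrative.
    return os.path.expanduser(f"~/.llama/checkpoints/{model}")


def pick_checkpoint_dir(model_id: str, checkpoint_dir: Optional[str] = None) -> str:
    # An explicitly configured checkpoint_dir always wins (mirrors generation.py).
    if checkpoint_dir and checkpoint_dir != "null":
        return checkpoint_dir
    resolved_model = resolve_model(model_id)
    if resolved_model is None:
        # not a native llama model: derive the directory from the raw model id
        return model_checkpoint_dir(model_id)
    # native llama model: derive the directory from the canonical descriptor
    # (based on the model's core_model_id value)
    return model_checkpoint_dir(resolved_model.descriptor())

For example, an id that resolve_model recognizes would be mapped through its descriptor, while an unrecognized fine-tuned model id falls back to a directory named after the id itself.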