refine

This commit is contained in: parent 9e5b7d5c9e, commit 7ab807ad76
3 changed files with 17 additions and 2 deletions
@@ -31,6 +31,7 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
+from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -113,7 +114,13 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model_id)
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
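
The else-branch added above replaces a direct model_checkpoint_dir(model_id) call with a resolve_model lookup, so native llama models are keyed by their canonical descriptor while anything else falls back to the raw model id. A minimal sketch of that decision flow follows; model_checkpoint_dir is stubbed here, and the ~/.llama/checkpoints layout is an assumption rather than something stated in this commit:

import os

from llama_models.sku_list import resolve_model


def model_checkpoint_dir(name: str) -> str:
    # stand-in for the provider's own helper; assumes checkpoints live under ~/.llama/checkpoints/<name>
    return os.path.expanduser(os.path.join("~/.llama/checkpoints", name))


def infer_checkpoint_dir(model_id: str, configured_dir: str | None = None) -> str:
    # an explicit config value wins, mirroring the config.checkpoint_dir != "null" guard above
    if configured_dir and configured_dir != "null":
        return configured_dir
    resolved_model = resolve_model(model_id)
    if resolved_model is None:
        # not a native llama model: key the directory off the raw model id
        return model_checkpoint_dir(model_id)
    # native llama model: key the directory off the canonical descriptor
    return model_checkpoint_dir(resolved_model.descriptor())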
@@ -107,6 +107,7 @@ class MetaReferenceInferenceImpl(
         pass
 
     async def register_model(self, model: Model) -> Model:
+        print("model metadata", model.metadata["llama_model"])
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
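
register_model now routes the optional llama_model metadata entry through resolve_model. For context, resolve_model maps a descriptor or alias to a llama_models SKU entry and returns None for anything it does not recognize, which is what the lookup above relies on. A small illustrative snippet, using an example descriptor that is not taken from this commit:

from llama_models.sku_list import resolve_model

resolved = resolve_model("Llama3.1-8B-Instruct")  # example descriptor, not part of this diff
if resolved is None:
    print("not a native llama model")
else:
    print(resolved.descriptor())  # canonical descriptor used to key checkpoint directories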
@@ -12,6 +12,7 @@ from typing import Any, Generator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
 
@@ -66,7 +67,13 @@ class LlamaModelParallelGenerator:
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model_id)
+        resolved_model = resolve_model(model_id)
+        if resolved_model is None:
+            # if the model is not a native llama model, get the default checkpoint_dir based on model id
+            checkpoint_dir = model_checkpoint_dir(model_id)
+        else:
+            # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
 
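
Once the checkpoint directory is resolved, the unchanged lines keep loading the tokenizer that ships next to the weights and wrapping it in ChatFormat, which the agent loop uses only for token counting. A hedged usage sketch, with an assumed checkpoint path:

import os

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

checkpoint_dir = os.path.expanduser("~/.llama/checkpoints/Llama3.1-8B-Instruct")  # assumed location
tokenizer = Tokenizer(os.path.join(checkpoint_dir, "tokenizer.model"))
formatter = ChatFormat(tokenizer)
print(len(tokenizer.encode("hello world", bos=True, eos=False)))  # rough input-length check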