mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

commit 7ab807ad76: refine
parent 9e5b7d5c9e
3 changed files with 17 additions and 2 deletions
@@ -31,6 +31,7 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
+from llama_models.sku_list import resolve_model
 from pydantic import BaseModel

 from llama_stack.apis.inference import *  # noqa: F403

@@ -113,7 +114,13 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model_id)
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
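
A minimal standalone sketch of the fallback this hunk introduces (not part of the commit): resolve_model and the example descriptor "Llama3.1-8B-Instruct" come from the llama_models package, while the model_checkpoint_dir helper below is a stand-in assumed to map a name to a local checkpoint folder, not the repo's actual implementation.

import os

from llama_models.sku_list import resolve_model


def model_checkpoint_dir(name: str) -> str:
    # assumption: checkpoints live under ~/.llama/checkpoints/<name>;
    # the real helper in the repo may resolve the path differently
    return os.path.expanduser(f"~/.llama/checkpoints/{name}")


def pick_checkpoint_dir(model_id: str) -> str:
    resolved_model = resolve_model(model_id)
    if resolved_model is None:
        # not a native llama model: key the directory on the raw model id
        return model_checkpoint_dir(model_id)
    # native llama model: key the directory on the canonical descriptor
    return model_checkpoint_dir(resolved_model.descriptor())


print(pick_checkpoint_dir("Llama3.1-8B-Instruct"))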

@@ -107,6 +107,7 @@ class MetaReferenceInferenceImpl(
         pass

     async def register_model(self, model: Model) -> Model:
+        print("model metadata", model.metadata["llama_model"])
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
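
For context, a hedged sketch of what the metadata lookup above evaluates to; the example metadata dict and the fallback to None are illustrative assumptions, since the hunk cuts off before the ternary's else branch.

from llama_models.sku_list import resolve_model

# hypothetical metadata a registered model might carry
metadata = {"llama_model": "Llama3.1-8B-Instruct"}
llama_model = (
    resolve_model(metadata["llama_model"])
    if "llama_model" in metadata
    else None
)
print(llama_model.descriptor() if llama_model else "no llama_model metadata")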

@@ -12,6 +12,7 @@ from typing import Any, Generator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest

@@ -66,7 +67,13 @@ class LlamaModelParallelGenerator:
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
-        checkpoint_dir = model_checkpoint_dir(self.model_id)
+        resolved_model = resolve_model(model_id)
+        if resolved_model is None:
+            # if the model is not a native llama model, get the default checkpoint_dir based on model id
+            checkpoint_dir = model_checkpoint_dir(model_id)
+        else:
+            # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
         tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
         self.formatter = ChatFormat(Tokenizer(tokenizer_path))
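
A small usage sketch of the tokenizer/formatter setup that follows the resolved checkpoint directory (illustrative only: the checkpoint path below is a placeholder, everything else mirrors the lines above).

import os

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

# placeholder path; in the diff this comes from model_checkpoint_dir(...)
checkpoint_dir = "/path/to/checkpoints/Llama3.1-8B-Instruct"
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
formatter = ChatFormat(Tokenizer(tokenizer_path))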