This commit is contained in:
Ashwin Bharambe 2024-09-28 15:21:32 -07:00
parent 37ca22cda6
commit 23028e26ff
7 changed files with 83 additions and 47 deletions

View file

@ -7,12 +7,13 @@
from typing import Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in all_registered_models()
if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
raise ValueError(

View file

@ -52,7 +52,7 @@ def model_checkpoint_dir(model) -> str:
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoint dir: {checkpoint_dir}."
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)
@ -185,11 +185,11 @@ class Llama:
) -> Generator:
params = self.model.params
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@ -207,6 +207,7 @@ class Llama:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
print(f"{is_vision=}")
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []