mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
bugfixes
This commit is contained in:
parent 37ca22cda6
commit 23028e26ff
7 changed files with 83 additions and 47 deletions
@@ -7,12 +7,13 @@
 from typing import Optional
 
 from llama_models.datatypes import *  # noqa: F403
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F401, F403
 from pydantic import BaseModel, Field, field_validator
 
+from llama_stack.providers.utils.inference import supported_inference_models
 
 
 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
@@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
-        permitted_models = [
-            m.descriptor()
-            for m in all_registered_models()
-            if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
-            or m.core_model_id == CoreModelId.llama_guard_3_8b
-        ]
+        permitted_models = supported_inference_models()
         if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
             raise ValueError(
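
For readers skimming the diff, a minimal self-contained sketch of the validator pattern this hunk converges on. The supported_inference_models() stub below is a hypothetical stand-in for llama_stack.providers.utils.inference.supported_inference_models, and the descriptor strings and error wording are illustrative only.

    from pydantic import BaseModel, Field, field_validator


    def supported_inference_models() -> list[str]:
        # Hypothetical stand-in for the real helper; assumed to return plain
        # model-descriptor strings for every model the provider can serve.
        return ["Llama3.1-8B-Instruct", "Llama3.2-3B-Instruct", "Llama-Guard-3-8B"]


    class MetaReferenceImplConfig(BaseModel):
        model: str = Field(default="Llama3.1-8B-Instruct")

        @field_validator("model")
        @classmethod
        def validate_model(cls, model: str) -> str:
            permitted_models = supported_inference_models()
            if model not in permitted_models:
                model_list = "\n\t".join(permitted_models)
                raise ValueError(f"Unknown model: `{model}`. Permitted models:\n\t{model_list}")
            return model


    # Constructing the config with an unsupported descriptor raises a pydantic ValidationError.
    config = MetaReferenceImplConfig(model="Llama3.1-8B-Instruct")

Centralizing the allow-list in a single helper keeps the config validator in sync with whatever the provider actually serves, instead of duplicating the model-family filter inline.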
@@ -52,7 +52,7 @@ def model_checkpoint_dir(model) -> str:
         checkpoint_dir = checkpoint_dir / "original"
 
     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoint dir: {checkpoint_dir}."
+        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
         f"Please download model using `llama download --model-id {model.descriptor()}`"
     )
     return str(checkpoint_dir)
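
A rough sketch of the checkpoint-resolution behavior the reworded assertion describes. model_local_dir is stubbed out here (the real helper ships with the llama tooling), and the .pth probe is an assumption about how the original/ fallback is chosen.

    from pathlib import Path


    def model_local_dir(descriptor: str) -> str:
        # Hypothetical stand-in for the real download-directory helper referenced in the diff.
        return str(Path.home() / ".llama" / "checkpoints" / descriptor)


    def model_checkpoint_dir(descriptor: str) -> str:
        checkpoint_dir = Path(model_local_dir(descriptor))
        if not any(checkpoint_dir.glob("*.pth")):
            # Some downloads keep the raw weights under an "original/" subdirectory.
            checkpoint_dir = checkpoint_dir / "original"

        assert checkpoint_dir.exists(), (
            f"Could not find checkpoints in: {model_local_dir(descriptor)}. "
            f"Please download model using `llama download --model-id {descriptor}`"
        )
        return str(checkpoint_dir)

Pointing the error at the download root rather than the possibly non-existent original/ path makes the failure actionable: it names the directory the user controls and the exact download command to run.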
@@ -185,11 +185,11 @@ class Llama:
     ) -> Generator:
         params = self.model.params
 
-        # input_tokens = [
-        #     self.formatter.vision_token if t == 128256 else t
-        #     for t in model_input.tokens
-        # ]
-        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+        input_tokens = [
+            self.formatter.vision_token if t == 128256 else t
+            for t in model_input.tokens
+        ]
+        cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
         prompt_tokens = [model_input.tokens]
 
         bsz = 1
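
The newly uncommented lines swap the image-placeholder id (128256 in the diff) for the formatter's vision token before decoding, so the prompt can be printed for inspection. A self-contained sketch of the same idea, using a hypothetical splice helper rather than the formatter's own mechanism:

    VISION_PLACEHOLDER_ID = 128256  # image-placeholder id as it appears in the diff


    def debug_render(tokens: list[int], decode, image_marker: str = "<|image|>") -> str:
        # Decode runs of ordinary token ids and splice in a printable marker wherever
        # the image placeholder appears, so a mixed text/image prompt prints cleanly.
        pieces, run = [], []
        for t in tokens:
            if t == VISION_PLACEHOLDER_ID:
                if run:
                    pieces.append(decode(run))
                    run = []
                pieces.append(image_marker)
            else:
                run.append(t)
        if run:
            pieces.append(decode(run))
        return "Input to model -> " + "".join(pieces)


    # Usage (assuming a tokenizer with a decode(list[int]) -> str method):
    #   cprint(debug_render(model_input.tokens, tokenizer.decode), "red")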
@@ -207,6 +207,7 @@ class Llama:
         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
 
         is_vision = isinstance(self.model, CrossAttentionTransformer)
+        print(f"{is_vision=}")
         if is_vision:
             images = model_input.vision.images if model_input.vision is not None else []
             mask = model_input.vision.mask if model_input.vision is not None else []
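
For context on the model_input.vision guard in this hunk, a minimal sketch of the shapes involved; the dataclasses below are hypothetical simplifications of the real llama_models types.

    from dataclasses import dataclass, field
    from typing import Any, List, Optional


    @dataclass
    class VisionInput:
        # Hypothetical simplification of the real vision-input container.
        images: List[Any] = field(default_factory=list)
        mask: List[List[int]] = field(default_factory=list)


    @dataclass
    class ModelInput:
        tokens: List[int]
        vision: Optional[VisionInput] = None


    def vision_inputs(model_input: ModelInput) -> tuple:
        # Text-only requests carry vision=None; normalize to empty lists so the
        # cross-attention (vision) code path can handle both cases uniformly.
        if model_input.vision is None:
            return [], []
        return model_input.vision.images, model_input.vision.mask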