Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)
fix test, fix llama3 generator

commit 771daa4b91, parent a3cee70014
2 changed files with 12 additions and 47 deletions
Llama3 generator (first of the two changed files):

@@ -154,7 +154,7 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        model_inputs: List[LLMInput],
+        llm_inputs: List[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
@@ -169,15 +169,15 @@ class Llama3:
 
         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            for inp in model_inputs:
+            for inp in llm_inputs:
                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
                 )
-        prompt_tokens = [inp.tokens for inp in model_inputs]
+        prompt_tokens = [inp.tokens for inp in llm_inputs]
 
-        bsz = len(model_inputs)
+        bsz = len(llm_inputs)
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
 
         min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -198,8 +198,8 @@ class Llama3:
 
         is_vision = not isinstance(self.model, Transformer)
        if is_vision:
-            images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
-            mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
+            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
 
             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                 batch_images=images,
@@ -234,7 +234,7 @@ class Llama3:
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                text_only_inference = all(inp.vision is None for inp in model_inputs)
+                text_only_inference = all(inp.vision is None for inp in llm_inputs)
             logits = self.model.forward(
                 position_ids,
                 tokens,
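The generator change above is purely a rename of the batch parameter from model_inputs to llm_inputs, so callers only need to update the keyword. The sketch below shows a hypothetical text-only call; the import path, the LLMInput constructor arguments, the tokenizer encode signature, and the assumption that generate yields per-step results are inferred from how the hunks above read inp.tokens and inp.vision, not details taken from this commit.

# Illustrative sketch only; import path and construction details are assumptions.
from llama_stack.models.llama.datatypes import LLMInput  # assumed location

def run_text_batch(generator, prompts, max_gen_len=64):
    # Build one LLMInput per prompt; vision=None keeps the text-only path,
    # matching the `inp.vision is None` checks in the diff above.
    llm_inputs = [
        LLMInput(tokens=generator.tokenizer.encode(p, bos=True, eos=False), vision=None)
        for p in prompts
    ]
    # Keyword renamed by this commit: was `model_inputs=`, now `llm_inputs=`.
    for step_results in generator.generate(
        llm_inputs=llm_inputs,
        temperature=0.6,
        top_p=0.9,
        max_gen_len=max_gen_len,
    ):
        yield step_results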
Batch inference test (second of the two changed files):

@@ -7,53 +7,17 @@
 
 import pytest
 
-from llama_stack.models.llama.sku_list import resolve_model
-
 from ..test_cases.test_case import TestCase
 
-PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
-
 
-def skip_if_model_doesnt_support_completion(client_with_models, model_id):
+def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in (
-        "remote::openai",
-        "remote::anthropic",
-        "remote::gemini",
-        "remote::groq",
-        "remote::llama-openai-compat",
-    ):
-        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
-
-
-def get_llama_model(client_with_models, model_id):
-    models = {}
-    for m in client_with_models.models.list():
-        models[m.identifier] = m
-        models[m.provider_resource_id] = m
-
-    assert model_id in models, f"Model {model_id} not found"
-
-    model = models[model_id]
-    ids = (model.identifier, model.provider_resource_id)
-    for mid in ids:
-        if resolve_model(mid):
-            return mid
-
-    return model.metadata.get("llama_model", None)
-
-
-def get_llama_tokenizer():
-    from llama_models.llama3.api.chat_format import ChatFormat
-    from llama_models.llama3.api.tokenizer import Tokenizer
-
-    tokenizer = Tokenizer.get_instance()
-    formatter = ChatFormat(tokenizer)
-    return tokenizer, formatter
+    if provider.provider_type not in ("inline::meta-reference",):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
 
 
 @pytest.mark.parametrize(
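The rewritten helper flips the old provider denylist into an allowlist: rather than skipping a hard-coded set of remote providers that lack completion support, it now skips any provider that is not explicitly known to support batch inference (currently only inline::meta-reference). A small standalone sketch of that inversion, with illustrative provider strings; this is not project code.

# Standalone illustration of the allowlist check used above.
BATCH_INFERENCE_PROVIDERS = ("inline::meta-reference",)

def should_skip_batch_inference(provider_type: str) -> bool:
    # Skip unless the provider is explicitly on the allowlist.
    return provider_type not in BATCH_INFERENCE_PROVIDERS

assert should_skip_batch_inference("remote::openai") is True
assert should_skip_batch_inference("inline::meta-reference") is False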
@@ -63,7 +27,7 @@ def get_llama_tokenizer():
     ],
 )
 def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
-    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     content_batch = tc["contents"]
@@ -87,6 +51,7 @@ def test_batch_completion_non_streaming(client_with_models, text_model_id, test_
     ],
 )
 def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
     tc = TestCase(test_case)
     qa_pairs = tc["qa_pairs"]
 