Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)
fix test, fix llama3 generator
commit 771daa4b91
parent a3cee70014
2 changed files with 12 additions and 47 deletions
@@ -154,7 +154,7 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        model_inputs: List[LLMInput],
+        llm_inputs: List[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
@@ -169,15 +169,15 @@ class Llama3:

         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            for inp in model_inputs:
+            for inp in llm_inputs:
                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
                 )
-        prompt_tokens = [inp.tokens for inp in model_inputs]
+        prompt_tokens = [inp.tokens for inp in llm_inputs]

-        bsz = len(model_inputs)
+        bsz = len(llm_inputs)
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

         min_prompt_len = min(len(t) for t in prompt_tokens)
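Aside, not part of the commit: the debug gate this hunk keeps around can be driven entirely from the environment. A minimal, self-contained sketch of just that expression:

import os

# LLAMA_MODELS_DEBUG=1 forces the prompt echo even when the
# print_model_input argument is left at its default of False.
print_model_input = False
os.environ["LLAMA_MODELS_DEBUG"] = "1"  # normally exported in the shell before launching
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
assert print_model_input is True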
@@ -198,8 +198,8 @@ class Llama3:

         is_vision = not isinstance(self.model, Transformer)
         if is_vision:
-            images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
-            mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
+            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]

             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                 batch_images=images,
@@ -234,7 +234,7 @@ class Llama3:
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                text_only_inference = all(inp.vision is None for inp in model_inputs)
+                text_only_inference = all(inp.vision is None for inp in llm_inputs)
                 logits = self.model.forward(
                     position_ids,
                     tokens,
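Aside, not part of the commit: a self-contained sketch of how a batch flows through the comprehensions above once the parameter is named llm_inputs. The stub classes below are stand-ins invented for illustration; only the tokens/vision/images/mask attributes they expose are taken from the diff itself.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class VisionStub:  # stand-in for the real vision payload
    images: List[object]
    mask: List[List[int]]


@dataclass
class LLMInputStub:  # stand-in for LLMInput; only the fields used above
    tokens: List[int]
    vision: Optional[VisionStub] = None


llm_inputs = [
    LLMInputStub(tokens=[1, 2, 3]),  # text-only entry
    LLMInputStub(tokens=[4, 5], vision=VisionStub(images=["img"], mask=[[0, 2]])),
]

prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(llm_inputs)
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
text_only_inference = all(inp.vision is None for inp in llm_inputs)

print(bsz, prompt_tokens, images, mask, text_only_inference)
# 2 [[1, 2, 3], [4, 5]] [[], ['img']] [[], [[0, 2]]] False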
@@ -7,53 +7,17 @@
 import pytest

-from llama_stack.models.llama.sku_list import resolve_model
-
 from ..test_cases.test_case import TestCase

-PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
-

-def skip_if_model_doesnt_support_completion(client_with_models, model_id):
+def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in (
-        "remote::openai",
-        "remote::anthropic",
-        "remote::gemini",
-        "remote::groq",
-        "remote::llama-openai-compat",
-    ):
-        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
-
-
-def get_llama_model(client_with_models, model_id):
-    models = {}
-    for m in client_with_models.models.list():
-        models[m.identifier] = m
-        models[m.provider_resource_id] = m
-
-    assert model_id in models, f"Model {model_id} not found"
-
-    model = models[model_id]
-    ids = (model.identifier, model.provider_resource_id)
-    for mid in ids:
-        if resolve_model(mid):
-            return mid
-
-    return model.metadata.get("llama_model", None)
-
-
-def get_llama_tokenizer():
-    from llama_models.llama3.api.chat_format import ChatFormat
-    from llama_models.llama3.api.tokenizer import Tokenizer
-
-    tokenizer = Tokenizer.get_instance()
-    formatter = ChatFormat(tokenizer)
-    return tokenizer, formatter
+    if provider.provider_type not in ("inline::meta-reference",):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")


 @pytest.mark.parametrize(
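Aside, not part of the commit: the renamed helper inverts the old logic. Instead of listing providers that lack completion support, it now skips everything except the inline meta-reference provider. A self-contained sketch of that gate, with names invented here for illustration:

SUPPORTED_FOR_BATCH = ("inline::meta-reference",)


def should_skip_batch_inference(provider_type: str) -> bool:
    # Mirrors the new check: skip unless the provider is inline::meta-reference.
    return provider_type not in SUPPORTED_FOR_BATCH


assert should_skip_batch_inference("remote::together") is True
assert should_skip_batch_inference("inline::meta-reference") is False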
@@ -63,7 +27,7 @@ def get_llama_tokenizer():
     ],
 )
 def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
-    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
     tc = TestCase(test_case)

     content_batch = tc["contents"]
@@ -87,6 +51,7 @@ def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
     ],
 )
 def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
     tc = TestCase(test_case)
     qa_pairs = tc["qa_pairs"]

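Aside, not part of the commit: both tests are driven by @pytest.mark.parametrize over test-case ids plus a TestCase lookup. A generic sketch of that pattern; the case id, its contents, and the TestCaseStub class are invented here, and only the tc["..."] access style comes from the diff:

import pytest

_CASES = {
    "batch_completion:sanity": {"contents": ["Hello", "Say hi"]},  # made-up case
}


class TestCaseStub:
    # Dict-backed stand-in for the real TestCase helper.
    def __init__(self, name: str):
        self._data = _CASES[name]

    def __getitem__(self, key: str):
        return self._data[key]


@pytest.mark.parametrize(
    "test_case",
    [
        "batch_completion:sanity",
    ],
)
def test_batch_completion_shape(test_case):
    tc = TestCaseStub(test_case)
    content_batch = tc["contents"]
    assert isinstance(content_batch, list) and len(content_batch) >= 1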