From 771daa4b911a52860a439457f0f553256b9da573 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Sat, 12 Apr 2025 10:51:43 -0700
Subject: [PATCH] fix test, fix llama3 generator

---
 llama_stack/models/llama/llama3/generation.py |  14 +++---
 .../inference/test_batch_inference.py         |  45 +++----------------
 2 files changed, 12 insertions(+), 47 deletions(-)

diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py
index 98412a1d4..35c140707 100644
--- a/llama_stack/models/llama/llama3/generation.py
+++ b/llama_stack/models/llama/llama3/generation.py
@@ -154,7 +154,7 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        model_inputs: List[LLMInput],
+        llm_inputs: List[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
@@ -169,15 +169,15 @@ class Llama3:
 
         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            for inp in model_inputs:
+            for inp in llm_inputs:
                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
                 )
-        prompt_tokens = [inp.tokens for inp in model_inputs]
+        prompt_tokens = [inp.tokens for inp in llm_inputs]
 
-        bsz = len(model_inputs)
+        bsz = len(llm_inputs)
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
 
         min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -198,8 +198,8 @@ class Llama3:
 
         is_vision = not isinstance(self.model, Transformer)
         if is_vision:
-            images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
-            mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
+            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
 
             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                 batch_images=images,
@@ -234,7 +234,7 @@ class Llama3:
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                text_only_inference = all(inp.vision is None for inp in model_inputs)
+                text_only_inference = all(inp.vision is None for inp in llm_inputs)
                 logits = self.model.forward(
                     position_ids,
                     tokens,
diff --git a/tests/integration/inference/test_batch_inference.py b/tests/integration/inference/test_batch_inference.py
index f2bbd0698..9a1a62ce0 100644
--- a/tests/integration/inference/test_batch_inference.py
+++ b/tests/integration/inference/test_batch_inference.py
@@ -7,53 +7,17 @@
 
 import pytest
 
-from llama_stack.models.llama.sku_list import resolve_model
-
 from ..test_cases.test_case import TestCase
 
-PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
-
 
-def skip_if_model_doesnt_support_completion(client_with_models, model_id):
+def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    if provider.provider_type in (
-        "remote::openai",
-        "remote::anthropic",
-        "remote::gemini",
-        "remote::groq",
-        "remote::llama-openai-compat",
-    ):
-        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
-
-
-def get_llama_model(client_with_models, model_id):
-    models = {}
-    for m in client_with_models.models.list():
-        models[m.identifier] = m
-        models[m.provider_resource_id] = m
-
-    assert model_id in models, f"Model {model_id} not found"
-
-    model = models[model_id]
-    ids = (model.identifier, model.provider_resource_id)
-    for mid in ids:
-        if resolve_model(mid):
-            return mid
-
-    return model.metadata.get("llama_model", None)
-
-
-def get_llama_tokenizer():
-    from llama_models.llama3.api.chat_format import ChatFormat
-    from llama_models.llama3.api.tokenizer import Tokenizer
-
-    tokenizer = Tokenizer.get_instance()
-    formatter = ChatFormat(tokenizer)
-    return tokenizer, formatter
+    if provider.provider_type not in ("inline::meta-reference",):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
 
 
 @pytest.mark.parametrize(
@@ -63,7 +27,7 @@
     "test_case",
     [
         "inference:completion:batch_completion",
    ],
 )
 def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
-    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     content_batch = tc["contents"]
@@ -87,6 +51,7 @@
     ],
 )
 def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
    tc = TestCase(test_case)
 
     qa_pairs = tc["qa_pairs"]