refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
This commit is contained in:
Ashwin Bharambe 2025-04-07 15:03:58 -07:00 committed by GitHub
parent c52ccc4bbd
commit 530d4bdfe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 1267 additions and 1683 deletions

View file

@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@ -76,21 +73,22 @@ class UseCase(BaseModel):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
input_tokens, output_tokens = generator.text_completion_raw(
dialog.content,
temperature=0.1,
top_p=0.95,
max_gen_len=64,
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=0.0,
top_p=0.95,
max_gen_len=self.max_gen_len,
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
# FIXME: This is added to undo the hack in chat_formatter where
@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):
text = ""
tokenizer = Tokenizer.get_instance()
temperature = 0.0
for dialog in self.dialogs:
if isinstance(dialog, str):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
# TODO pass the raw input and do the encoding in the text completion function
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
llm_input = LLMInput(tokens=input_tokens)
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=temperature,
max_gen_len=self.max_gen_len,
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.0):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
text += _code_block(tokenizer.decode(input_tokens))