refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
This commit is contained in:
Ashwin Bharambe 2025-04-07 15:03:58 -07:00 committed by GitHub
parent c52ccc4bbd
commit 530d4bdfe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
85 changed files with 1267 additions and 1683 deletions

View file

@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel):
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
quantization: Optional[QuantizationConfig] = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
@ -47,27 +50,16 @@ class MetaReferenceInferenceConfig(BaseModel):
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
}
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config