guided decoding initial draft

Ashwin Bharambe 2024-10-21 18:44:19 -07:00
parent 1d241bf3fe
commit 6d26bbdce3
4 changed files with 133 additions and 22 deletions

@@ -12,7 +12,7 @@ import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
 from termcolor import cprint

 from llama_stack.apis.inference import *  # noqa: F403

+import math
+
+from lmformatenforcer import JsonSchemaParser
+from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
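The AnswerFormat schema is hard-coded in this draft; JsonSchemaParser consumes the Pydantic-generated JSON schema, so decoding is constrained to an object with exactly these typed fields. A minimal sketch of what the parser is given and what a conforming completion looks like (illustrative values, not from this commit):

    from lmformatenforcer import JsonSchemaParser

    # AnswerFormat.schema() (Pydantic v1) yields a JSON-schema dict along the
    # lines of:
    #   {"title": "AnswerFormat", "type": "object",
    #    "properties": {"first_name": {"type": "string"}, ...},
    #    "required": [...]}
    parser = JsonSchemaParser(AnswerFormat.schema())

    # A completion the enforcer would permit:
    #   {"first_name": "Michael", "last_name": "Jordan",
    #    "year_of_birth": 1963, "num_seasons_in_nba": 15}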
@@ -172,9 +185,16 @@ class Llama:
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
     ) -> Generator:
+        parser = JsonSchemaParser(AnswerFormat.schema())
+        tokenizer_data = build_token_enforcer_tokenizer_data(
+            self.tokenizer, self.args.vocab_size
+        )
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
+        logits_processor = LogitsProcessor(token_enforcer)
+
         params = self.model.params

-        if print_input_tokens:
+        if print_input_tokens or True:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
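Two draft artifacts worth flagging here: `or True` force-enables input-token printing for debugging, and the enforcer stack is rebuilt on every generate() call, including the full-vocabulary scan inside build_token_enforcer_tokenizer_data. A sketch of caching the expensive part across requests, assuming a hypothetical module-level helper (not in this commit):

    _TOKENIZER_DATA: Optional[TokenEnforcerTokenizerData] = None

    def get_logits_processor(tokenizer, vocab_size) -> "LogitsProcessor":
        # Hypothetical caching helper: the tokenizer data depends only on the
        # tokenizer and vocab size, so compute it once; the parser and
        # TokenEnforcer are rebuilt per request.
        global _TOKENIZER_DATA
        if _TOKENIZER_DATA is None:
            _TOKENIZER_DATA = build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
        parser = JsonSchemaParser(AnswerFormat.schema())
        return LogitsProcessor(TokenEnforcer(_TOKENIZER_DATA, parser))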
@@ -246,6 +266,11 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            # print(f"{logits=}")
+
+            input_ids = tokens[0, :cur_pos].tolist()
+            # logits = logits_processor.process_logits(input_ids, logits)
+            # print(f"{logits=}")

             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
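The guided-decoding hook itself is still commented out here; when enabled, process_logits adds -inf to every disallowed token id, which the existing softmax/top-p path turns into exactly-zero probabilities. A tiny self-contained illustration of that masking trick:

    import math

    import torch

    scores = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # (bsz, seqlen, vocab)
    mask = torch.full_like(scores, -math.inf)
    mask[:, :, torch.tensor([1, 3])] = 0  # ids the enforcer allows
    probs = torch.softmax(scores + mask, dim=-1)
    # probs[0, 0] -> tensor([0.0000, 0.1192, 0.0000, 0.8808]); masked ids get 0.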
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = input_ids
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        # print(f"{allowed_tokens=}")
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
+
+
+def build_token_enforcer_tokenizer_data(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+) -> TokenEnforcerTokenizerData:
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
+    return TokenEnforcerTokenizerData(
+        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
+    )
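Taken together, a minimal end-to-end sketch of how these pieces are meant to compose once the commented-out call in the decode loop is enabled (names from this diff; the loop body is schematic). Note that only row 0 of the token batch is fed to the enforcer, so the draft effectively assumes batch size 1, and LogitsProcessor refills its cached -inf mask in place each step rather than reallocating a vocab-sized tensor per token:

    tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
    enforcer = TokenEnforcer(tokenizer_data, JsonSchemaParser(AnswerFormat.schema()))
    logits_processor = LogitsProcessor(enforcer)

    # Inside the decode loop, before temperature scaling and sampling:
    #   input_ids = tokens[0, :cur_pos].tolist()
    #   logits = logits_processor.process_logits(input_ids, logits)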