Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 18:02:37 +00:00)
guided decoding initial draft

commit 6d26bbdce3 (parent 1d241bf3fe)
4 changed files with 133 additions and 22 deletions
@@ -12,7 +12,7 @@ import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
+import math
+
+from lmformatenforcer import JsonSchemaParser
+
+from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
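
Note: the new lmformatenforcer imports introduce lm-format-enforcer as a provider dependency (PyPI name assumed from the module name; the commit's other changed files are not shown in this view). A quick import check, separate from the diff:

# Not part of the diff: smoke-test the assumed new dependency,
# installed with e.g. `pip install lm-format-enforcer`.
from lmformatenforcer import JsonSchemaParser  # noqa: F401
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData  # noqa: F401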
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
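
For context, the hardcoded AnswerFormat model above is the draft's placeholder target schema. A small illustrative snippet (not part of the diff) showing the JSON Schema dict that AnswerFormat.schema() hands to JsonSchemaParser, using pydantic's v1-style .schema():

# Illustration only: the schema dict the draft feeds to the JSON-schema parser.
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


print(AnswerFormat.schema())
# {'title': 'AnswerFormat', 'type': 'object',
#  'properties': {'first_name': {'title': 'First Name', 'type': 'string'},
#                 'last_name': {'title': 'Last Name', 'type': 'string'},
#                 'year_of_birth': {'title': 'Year Of Birth', 'type': 'integer'},
#                 'num_seasons_in_nba': {'title': 'Num Seasons In Nba', 'type': 'integer'}},
#  'required': ['first_name', 'last_name', 'year_of_birth', 'num_seasons_in_nba']}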
@@ -172,9 +185,16 @@ class Llama:
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
     ) -> Generator:
+        parser = JsonSchemaParser(AnswerFormat.schema())
+        tokenizer_data = build_token_enforcer_tokenizer_data(
+            self.tokenizer, self.args.vocab_size
+        )
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
+        logits_processor = LogitsProcessor(token_enforcer)
+
         params = self.model.params
 
-        if print_input_tokens:
+        if print_input_tokens or True:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -246,6 +266,11 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
+            # print(f"{logits=}")
+            input_ids = tokens[0, :cur_pos].tolist()
+            # logits = logits_processor.process_logits(input_ids, logits)
+            # print(f"{logits=}")
+
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
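
The process_logits call above is still commented out in this draft; when enabled it would mask the logits before the temperature/top-p step that follows. A self-contained toy of that masking (illustration only: the vocab size and allowed-token set are made up and stand in for token_enforcer.get_allowed_tokens):

# Toy version of the masking the commented-out call would apply.
import math

import torch

vocab_size = 8
logits = torch.randn(1, 1, vocab_size)  # (bsz, seqlen, vocab), matching the [:, -1] indexing above
allowed = torch.tensor([2, 5, 7])       # stand-in for the enforcer's allowed-token list

mask = torch.full_like(logits, -math.inf)
mask[:, :, allowed] = 0
constrained = logits + mask

probs = torch.softmax(constrained[:, -1] / 0.7, dim=-1)
print(probs[0, [0, 1, 3, 4, 6]].sum().item())  # 0.0 -- disallowed tokens get zero probability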
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = input_ids
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        # print(f"{allowed_tokens=}")
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
+
+
+def build_token_enforcer_tokenizer_data(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+) -> TokenEnforcerTokenizerData:
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
+    return TokenEnforcerTokenizerData(
+        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
+    )
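
A usage sketch (not part of the diff) tying the new helpers together the same way generate() wires them at the top of this commit; it assumes a loaded meta-reference Tokenizer, its vocab size, and the LogitsProcessor / build_token_enforcer_tokenizer_data definitions added above:

# Sketch: build a schema-constrained logits processor from a loaded Tokenizer.
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.tokenenforcer import TokenEnforcer


def make_guided_logits_processor(tokenizer, vocab_size: int, schema: dict) -> LogitsProcessor:
    tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
    token_enforcer = TokenEnforcer(tokenizer_data, JsonSchemaParser(schema))
    return LogitsProcessor(token_enforcer)


# processor = make_guided_logits_processor(tokenizer, vocab_size, AnswerFormat.schema())
# ...then, per decoding step inside the loop:
#     input_ids = tokens[0, :cur_pos].tolist()
#     logits = processor.process_logits(input_ids, logits)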