Fix and add a test

Ashwin Bharambe 2024-10-21 22:02:37 -07:00
parent 40ba22f4c8
commit cd84dee3e9
5 changed files with 76 additions and 62 deletions

@@ -8,6 +8,7 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import json
+import math
 import os
 import sys
 import time
@@ -34,11 +35,8 @@ from pydantic import BaseModel
 from termcolor import cprint

 from llama_stack.apis.inference import *  # noqa: F403

-import math
-
-from lmformatenforcer import JsonSchemaParser
-from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -48,13 +46,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
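The AnswerFormat model deleted above was a test fixture that had leaked into the provider module; per the commit title it now lives in a test. As a rough sketch of how such a fixture feeds the structured-output path changed below, assuming Pydantic v2 (model_json_schema(); use .schema() on v1). The construction of response_format is an assumption, not the actual llama_stack API: the diff only tells us the object exposes .type (compared against ResponseFormatType.json_schema.value) and .schema (fed to JsonSchemaParser).

    from pydantic import BaseModel

    class AnswerFormat(BaseModel):
        first_name: str
        last_name: str
        year_of_birth: int
        num_seasons_in_nba: int

    # Hypothetical constructor shape, inferred from the fields the diff reads.
    response_format = ResponseFormat(
        type="json_schema",
        schema=AnswerFormat.model_json_schema(),
    )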
@@ -261,9 +252,7 @@ class Llama:
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if logits_processor is not None:
-                logits = logits_processor.process_logits(
-                    tokens[0, :cur_pos].tolist(), logits
-                )
+                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
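The fix in this hunk is an interface mismatch: the old call passed a flat List[int] of batch-0 token ids, while the rewritten process_logits (next hunk) expects the 2D token tensor and extracts the list itself. A quick, purely illustrative shape check of the two conventions:

    import torch

    tokens = torch.zeros((1, 32), dtype=torch.long)  # (batch, seq_len)
    cur_pos = 7

    old_arg = tokens[0, :cur_pos].tolist()  # flat List[int] (pre-commit convention)
    new_arg = tokens[:, :cur_pos]           # 2D tensor (post-commit convention)

    # The new process_logits recovers the old argument from batch row 0:
    assert new_arg[0, :].tolist() == old_arg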
@@ -401,15 +390,36 @@ def sample_top_p(probs, p):
     return next_token


+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, tokens: torch.Tensor, scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = tokens[0, :].tolist()
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
 def get_logits_processor(
     tokenizer: Tokenizer,
     vocab_size: int,
     response_format: Optional[ResponseFormat],
-) -> Optional[LogitsProcessor]:
+) -> Optional["LogitsProcessor"]:
     if response_format is None:
         return None

-    if response_format.type != ResponseFormatType.json:
+    if response_format.type != ResponseFormatType.json_schema.value:
         raise ValueError(f"Unsupported response format type {response_format.type}")

     parser = JsonSchemaParser(response_format.schema)
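The relocated LogitsProcessor implements constrained decoding: at each step the TokenEnforcer reports which token ids keep the partial output valid under the JSON schema, and every other id is pushed to -inf before sampling. A self-contained sketch of that masking trick, with a stub standing in for lmformatenforcer's TokenEnforcer (the stub and the numbers are illustrative):

    import math
    import torch

    class StubEnforcer:
        def get_allowed_tokens(self, token_sequence):
            return [3, 7]  # pretend only these ids keep the output schema-legal

    scores = torch.randn(1, 1, 16)  # (batch, 1, vocab_size) logits
    allowed = StubEnforcer().get_allowed_tokens([1, 2])

    mask = torch.full_like(scores, -math.inf)
    mask[:, :, allowed] = 0  # zero penalty for allowed ids only
    masked = scores + mask

    # After softmax, all probability mass sits on the allowed ids.
    probs = torch.softmax(masked[:, -1], dim=-1)
    assert probs[0, [3, 7]].sum().item() > 0.999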
@@ -422,28 +432,6 @@ def get_logits_processor(
     return LogitsProcessor(token_enforcer)


-class LogitsProcessor:
-    def __init__(self, token_enforcer: TokenEnforcer):
-        self.token_enforcer = token_enforcer
-        self.mask: Optional[torch.Tensor] = None
-
-    def process_logits(
-        self, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        token_sequence = input_ids
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-        # print(f"{allowed_tokens=}")
-        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
-
-        if self.mask is not None:
-            self.mask.fill_(-math.inf)
-        else:
-            self.mask = torch.full_like(scores, -math.inf)
-
-        self.mask[:, :, allowed_tokens] = 0
-        scores = scores + self.mask
-        return scores
-
 def _build_regular_tokens_list(
     tokenizer: Tokenizer, vocab_size: int
 ) -> List[Tuple[int, str, bool]]:
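Net effect of the last two hunks: the duplicate LogitsProcessor definition is gone, the surviving copy sits above its only caller, and it now takes the token tensor directly (the old torch.tensor(allowed_tokens, ...) conversion also disappears, since a plain Python list indexes a tensor fine). A hedged sketch of the sort of unit test the commit title implies; the actual test added by this commit is in one of the other changed files not shown here, so everything below is illustrative:

    import math
    from typing import Optional

    import torch

    class LogitsProcessor:  # trimmed copy of the class from the hunk above
        def __init__(self, token_enforcer):
            self.token_enforcer = token_enforcer
            self.mask: Optional[torch.Tensor] = None

        def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            allowed = self.token_enforcer.get_allowed_tokens(tokens[0, :].tolist())
            if self.mask is not None:
                self.mask.fill_(-math.inf)
            else:
                self.mask = torch.full_like(scores, -math.inf)
            self.mask[:, :, allowed] = 0
            return scores + self.mask

    class FakeEnforcer:  # stands in for lmformatenforcer.TokenEnforcer
        def get_allowed_tokens(self, token_sequence):
            return [0]  # only token id 0 is ever legal

    def test_process_logits_masks_disallowed_tokens():
        processor = LogitsProcessor(FakeEnforcer())
        tokens = torch.zeros((1, 4), dtype=torch.long)
        scores = torch.randn(1, 1, 8)

        out = processor.process_logits(tokens, scores)

        assert out[0, 0, 0].item() != -math.inf  # allowed id survives
        assert all(out[0, 0, i].item() == -math.inf for i in range(1, 8))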