mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 13:02:38 +00:00)
Fix and add a test

parent 40ba22f4c8 · commit cd84dee3e9
5 changed files with 76 additions and 62 deletions
@@ -8,6 +8,7 @@
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 
 import json
+import math
 import os
 import sys
 import time
@@ -34,11 +35,8 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
-import math
-
-from lmformatenforcer import JsonSchemaParser
-
-from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -48,13 +46,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
-class AnswerFormat(BaseModel):
-    first_name: str
-    last_name: str
-    year_of_birth: int
-    num_seasons_in_nba: int
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
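Review note: the AnswerFormat scratch schema is dropped from the module; given the commit title, it presumably moves into the new test (not shown in these hunks). For orientation, a minimal sketch of how such a Pydantic model is typically handed to lm-format-enforcer; the pydantic-v2 spelling is an assumption, not code from this diff:

from pydantic import BaseModel

from lmformatenforcer import JsonSchemaParser


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


# pydantic v2 spelling; under pydantic v1 this would be AnswerFormat.schema()
parser = JsonSchemaParser(AnswerFormat.model_json_schema())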
@@ -261,9 +252,7 @@ class Llama:
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
             if logits_processor is not None:
-                logits = logits_processor.process_logits(
-                    tokens[0, :cur_pos].tolist(), logits
-                )
+                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
 
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
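Review note: the call site now passes the raw (bsz, seqlen) token tensor instead of a flattened Python list; the rewritten process_logits does the batch-0 slice itself. A toy illustration of the two argument shapes (token values invented):

import torch

tokens = torch.tensor([[128000, 9906, 11, 1917]])  # (bsz=1, seqlen=4) generation buffer
cur_pos = 3

old_arg = tokens[0, :cur_pos].tolist()  # [128000, 9906, 11] -- a List[int]
new_arg = tokens[:, :cur_pos]           # a tensor of shape (1, 3)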
@@ -401,15 +390,36 @@ def sample_top_p(probs, p):
     return next_token
 
 
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, tokens: torch.Tensor, scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = tokens[0, :].tolist()
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
 def get_logits_processor(
     tokenizer: Tokenizer,
     vocab_size: int,
     response_format: Optional[ResponseFormat],
-) -> Optional[LogitsProcessor]:
+) -> Optional["LogitsProcessor"]:
     if response_format is None:
         return None
 
-    if response_format.type != ResponseFormatType.json:
+    if response_format.type != ResponseFormatType.json_schema.value:
         raise ValueError(f"Unsupported response format type {response_format.type}")
 
     parser = JsonSchemaParser(response_format.schema)
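Review note: the relocated LogitsProcessor keeps the same masking trick: a reusable tensor filled with -inf, zeroed at the vocabulary ids the TokenEnforcer allows, then added onto the scores so every disallowed token gets zero probability after softmax. A self-contained sketch of just that step (shapes and values invented):

import math

import torch

scores = torch.randn(1, 1, 8)  # (bsz, 1, vocab_size), values invented
allowed_tokens = [2, 5, 7]     # whatever the TokenEnforcer permits next

mask = torch.full_like(scores, -math.inf)
mask[:, :, allowed_tokens] = 0  # allowed ids pass through unchanged
constrained = scores + mask     # every other id becomes -inf

next_token = torch.argmax(constrained[:, -1], dim=-1)
assert next_token.item() in allowed_tokens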
@@ -422,28 +432,6 @@ def get_logits_processor(
     return LogitsProcessor(token_enforcer)
 
 
-class LogitsProcessor:
-    def __init__(self, token_enforcer: TokenEnforcer):
-        self.token_enforcer = token_enforcer
-        self.mask: Optional[torch.Tensor] = None
-
-    def process_logits(
-        self, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        token_sequence = input_ids
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-        # print(f"{allowed_tokens=}")
-        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
-        if self.mask is not None:
-            self.mask.fill_(-math.inf)
-        else:
-            self.mask = torch.full_like(scores, -math.inf)
-
-        self.mask[:, :, allowed_tokens] = 0
-        scores = scores + self.mask
-        return scores
-
-
 def _build_regular_tokens_list(
     tokenizer: Tokenizer, vocab_size: int
 ) -> List[Tuple[int, str, bool]]:
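Review note: the body of get_logits_processor between these two hunks is elided; it presumably wires _build_regular_tokens_list's output into a TokenEnforcer before the `return LogitsProcessor(token_enforcer)` shown above. A hypothetical helper sketching how lm-format-enforcer's pieces fit together; the function name, and the assumption that TokenEnforcerTokenizerData takes (regular_tokens, decoder, eos_token_id) positionally, are mine rather than the diff's:

from typing import Callable, List, Tuple, Union

from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData


# Hypothetical helper, not code from this commit.
def build_token_enforcer(
    regular_tokens: List[Tuple[int, str, bool]],  # matches _build_regular_tokens_list's return type
    decode: Callable[[List[int]], str],           # e.g. tokenizer.decode
    eos_token_id: Union[int, List[int]],
    parser: JsonSchemaParser,
) -> TokenEnforcer:
    tokenizer_data = TokenEnforcerTokenizerData(regular_tokens, decode, eos_token_id)
    return TokenEnforcer(tokenizer_data, parser)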