mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 04:22:35 +00:00)

commit 40ba22f4c8 (parent 6d26bbdce3)
add response format to signature

15 changed files with 93 additions and 32 deletions
@@ -80,7 +80,7 @@ class Llama:
     def build(
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
-        ]
+        ],
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -184,17 +184,11 @@ class Llama:
         echo: bool = False,
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
+        logits_processor: Optional["LogitsProcessor"] = None,
     ) -> Generator:
-        parser = JsonSchemaParser(AnswerFormat.schema())
-        tokenizer_data = build_token_enforcer_tokenizer_data(
-            self.tokenizer, self.args.vocab_size
-        )
-        token_enforcer = TokenEnforcer(tokenizer_data, parser)
-        logits_processor = LogitsProcessor(token_enforcer)
-
         params = self.model.params

-        if print_input_tokens or True:
+        if print_input_tokens:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
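Note: with this hunk, generate() no longer builds an lm-format-enforcer pipeline unconditionally from the hard-coded AnswerFormat schema (and no longer forces token printing via the `or True` debug guard); enforcement is now injected through the new optional logits_processor argument. A minimal caller-side sketch of the new shape, assuming the get_logits_processor helper added at the end of this diff; `llama`, `model_input`, and `request` are hypothetical stand-ins for objects built elsewhere:

    # Sketch only: everything except logits_processor already existed in
    # generate()'s signature.
    processor = get_logits_processor(
        llama.tokenizer,
        llama.args.vocab_size,
        request.response_format,  # Optional[ResponseFormat]; None yields None
    )
    for token_result in llama.generate(
        model_input,
        temperature=0.6,
        include_stop_token=True,
        logits_processor=processor,  # new optional parameter
    ):
        ...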
@@ -266,10 +260,10 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

-            # print(f"{logits=}")
-            input_ids = tokens[0, :cur_pos].tolist()
-            # logits = logits_processor.process_logits(input_ids, logits)
-            # print(f"{logits=}")
+            if logits_processor is not None:
+                logits = logits_processor.process_logits(
+                    tokens[0, :cur_pos].tolist(), logits
+                )

             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
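Note: the commented-out experiments are replaced by a guarded call, so the decode loop only pays for schema enforcement when a processor was supplied. The body of LogitsProcessor.process_logits is outside this hunk; a plausible minimal implementation, assuming lm-format-enforcer's TokenEnforcer.get_allowed_tokens and logits shaped [batch, seq, vocab]:

    import torch
    from lmformatenforcer import TokenEnforcer

    class LogitsProcessorSketch:
        """Illustrative stand-in for the LogitsProcessor defined below."""

        def __init__(self, token_enforcer: TokenEnforcer):
            self.token_enforcer = token_enforcer

        def process_logits(self, input_ids, logits):
            # Ask the enforcer which token ids may legally extend the prefix,
            # then push every other logit to -inf so sampling/argmax can only
            # emit schema-conforming continuations.
            allowed = self.token_enforcer.get_allowed_tokens(input_ids)
            mask = torch.full_like(logits[:, -1], float("-inf"))
            mask[:, allowed] = 0.0
            logits[:, -1] = logits[:, -1] + mask
            return logits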
@@ -342,7 +336,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
             include_stop_token=True,
-            echo=False,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )

     def chat_completion(
@@ -370,6 +368,11 @@ class Llama:
             top_p=sampling_params.top_p,
             logprobs=bool(request.logprobs),
+            include_stop_token=True,
+            logits_processor=get_logits_processor(
+                self.tokenizer,
+                self.args.vocab_size,
+                request.response_format,
+            ),
         )
-

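Note: both the completion and chat_completion paths now thread request.response_format into get_logits_processor, making structured output opt-in per request. A hedged sketch of the request side, using only the `type` and `schema` attributes this diff reads (the actual ResponseFormat constructor and request plumbing live outside this file):

    from pydantic import BaseModel

    class Answer(BaseModel):
        answer: str
        confidence: float

    # Hypothetical construction; field names mirror the attributes accessed
    # in this diff (response_format.type, response_format.schema).
    request.response_format = ResponseFormat(
        type=ResponseFormatType.json,
        schema=Answer.schema(),  # a JSON-schema dict, as JsonSchemaParser expects
    )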
@@ -398,6 +401,27 @@ def sample_top_p(probs, p):
     return next_token


+def get_logits_processor(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+    response_format: Optional[ResponseFormat],
+) -> Optional[LogitsProcessor]:
+    if response_format is None:
+        return None
+
+    if response_format.type != ResponseFormatType.json:
+        raise ValueError(f"Unsupported response format type {response_format.type}")
+
+    parser = JsonSchemaParser(response_format.schema)
+    data = TokenEnforcerTokenizerData(
+        _build_regular_tokens_list(tokenizer, vocab_size),
+        tokenizer.decode,
+        tokenizer.stop_tokens,
+    )
+    token_enforcer = TokenEnforcer(data, parser)
+    return LogitsProcessor(token_enforcer)
+
+
 class LogitsProcessor:
     def __init__(self, token_enforcer: TokenEnforcer):
         self.token_enforcer = token_enforcer
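Note: the new helper degrades gracefully: no response format yields no processor (so generate() runs unconstrained), and any non-JSON format fails fast instead of silently producing free-form text. A small usage sketch, assuming the module's Tokenizer and the hypothetical ResponseFormat construction shown above; the vocab size and tokenizer path are illustrative only:

    tokenizer = Tokenizer(model_path="/path/to/tokenizer.model")  # hypothetical path

    # None in, None out: requests without a response_format are unaffected.
    assert get_logits_processor(tokenizer, 128256, None) is None

    fmt = ResponseFormat(type=ResponseFormatType.json, schema={"type": "object"})
    processor = get_logits_processor(tokenizer, 128256, fmt)
    # processor.process_logits(prefix_ids, logits) is what the decode loop calls.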
@@ -437,13 +461,3 @@ def _build_regular_tokens_list(
         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
     return regular_tokens
-
-
-def build_token_enforcer_tokenizer_data(
-    tokenizer: Tokenizer,
-    vocab_size: int,
-) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
-    return TokenEnforcerTokenizerData(
-        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
-    )
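Note: build_token_enforcer_tokenizer_data is dropped because get_logits_processor now assembles TokenEnforcerTokenizerData itself. For orientation, the visible tail of _build_regular_tokens_list implies the usual lm-format-enforcer setup: decode each token behind a throwaway leading token so leading-space information survives, and flag word-start tokens. A hedged reconstruction of the full loop (only its last three lines appear in this hunk):

    def _build_regular_tokens_list_sketch(tokenizer, vocab_size):
        regular_tokens = []
        for token_idx in range(vocab_size):
            # Assumption: special tokens are skipped; only regular vocab
            # entries participate in enforcement.
            if token_idx in tokenizer.special_tokens.values():
                continue
            # Decoding behind token 0 preserves a leading space that decoding
            # the token alone would strip; the [1:] slice assumes token 0
            # decodes to a single character.
            decoded_after_0 = tokenizer.decode([0, token_idx])[1:]
            decoded_regular = tokenizer.decode([token_idx])
            is_word_start_token = len(decoded_after_0) > len(decoded_regular)
            regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
        return regular_tokens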