guided decoding initial draft

Ashwin Bharambe 2024-10-21 18:44:19 -07:00
parent 1d241bf3fe
commit 6d26bbdce3
4 changed files with 133 additions and 22 deletions

View file

@@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: Optional[StopReason] = None
class JsonResponseFormat(BaseModel):
type: Literal["json"] = "json"
schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal["grammar"] = "grammar"
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
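
The `ResponseFormat` union introduced here is discriminated on `type`, so a request body can carry either a JSON-schema constraint or a BNF grammar and Pydantic resolves the right variant automatically. A minimal construction sketch (not part of the commit; the schema payload and model name are placeholders):

```python
# Sketch only: illustrative payloads, not taken from the commit.
json_format = JsonResponseFormat(
    schema={
        "type": "object",
        "properties": {"first_name": {"type": "string"}},
        "required": ["first_name"],
    }
)
grammar_format = GrammarResponseFormat(bnf={"root": 'root ::= "yes" | "no"'})

request = CompletionRequest(
    model="Llama3.1-8B-Instruct",      # placeholder model id
    content="Describe Michael Jordan as JSON.",
    response_format=json_format,       # or grammar_format
)
```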
@@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@@ -188,6 +207,7 @@ class Inference(Protocol):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@@ -204,6 +224,7 @@ class Inference(Protocol):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
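
With `response_format` threaded through both `completion` and `chat_completion`, a caller could request schema-constrained output along the lines of the hedged sketch below (the `inference` object and the reuse of `AnswerFormat` from the test changes are assumptions, not part of the diff):

```python
# Hedged sketch: `inference` stands for any Inference implementation.
response = await inference.chat_completion(
    model="Llama3.1-8B-Instruct",      # placeholder model id
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Give me information about Michael Jordan."),
    ],
    response_format=JsonResponseFormat(schema=AnswerFormat.schema()),
    stream=False,
)
# The completion content is expected to be JSON conforming to AnswerFormat.
print(response.completion_message.content)
```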

View file

@@ -12,7 +12,7 @@ import os
import sys
import time
from pathlib import Path
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
import math
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
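
`AnswerFormat` is a hardcoded debug schema for this draft; `JsonSchemaParser` consumes its Pydantic-generated JSON schema. Roughly (the exact shape depends on the Pydantic version), `AnswerFormat.schema()` yields:

```python
# Approximate output of AnswerFormat.schema(); shown for orientation only.
{
    "title": "AnswerFormat",
    "type": "object",
    "properties": {
        "first_name": {"title": "First Name", "type": "string"},
        "last_name": {"title": "Last Name", "type": "string"},
        "year_of_birth": {"title": "Year Of Birth", "type": "integer"},
        "num_seasons_in_nba": {"title": "Num Seasons In Nba", "type": "integer"},
    },
    "required": ["first_name", "last_name", "year_of_birth", "num_seasons_in_nba"],
}
```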
@@ -172,9 +185,16 @@ class Llama:
include_stop_token: bool = False,
print_input_tokens: bool = False,
) -> Generator:
parser = JsonSchemaParser(AnswerFormat.schema())
tokenizer_data = build_token_enforcer_tokenizer_data(
self.tokenizer, self.args.vocab_size
)
token_enforcer = TokenEnforcer(tokenizer_data, parser)
logits_processor = LogitsProcessor(token_enforcer)
params = self.model.params
if print_input_tokens:
if print_input_tokens or True:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
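
In this draft the enforcer is always built from the hardcoded `AnswerFormat` schema, and the `or True` keeps input-token printing on for debugging. A hedged sketch of how the setup would presumably be gated on the request's `response_format` once that value is plumbed into `generate()` (its availability here is an assumption):

```python
# Hedged sketch, assuming the request's response_format reaches generate().
logits_processor = None
if response_format is not None:
    if isinstance(response_format, JsonResponseFormat):
        parser = JsonSchemaParser(response_format.schema)
    else:
        raise NotImplementedError("grammar-based guided decoding not wired up yet")
    tokenizer_data = build_token_enforcer_tokenizer_data(
        self.tokenizer, self.args.vocab_size
    )
    logits_processor = LogitsProcessor(TokenEnforcer(tokenizer_data, parser))
```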
@@ -246,6 +266,11 @@ class Llama:
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
# print(f"{logits=}")
input_ids = tokens[0, :cur_pos].tolist()
# logits = logits_processor.process_logits(input_ids, logits)
# print(f"{logits=}")
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
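
The actual logits filtering is still commented out in this draft. Enabled, the guided step would sit just before sampling, roughly as sketched below (the `logits_processor is not None` gate is an assumption about per-request opt-in):

```python
# Sketch: applying the enforcer mask before temperature/top-p sampling.
input_ids = tokens[0, :cur_pos].tolist()
if logits_processor is not None:
    # Adds -inf to every vocabulary entry the enforcer disallows at this step.
    logits = logits_processor.process_logits(input_ids, logits)

if temperature > 0:
    probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
    next_token = sample_top_p(probs, top_p)
else:
    next_token = torch.argmax(logits[:, -1], dim=-1)
```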
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
token_sequence = input_ids
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
# print(f"{allowed_tokens=}")
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens
def build_token_enforcer_tokenizer_data(
tokenizer: Tokenizer,
vocab_size: int,
) -> TokenEnforcerTokenizerData:
regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
return TokenEnforcerTokenizerData(
regular_tokens, tokenizer.decode, tokenizer.stop_tokens
)
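
The new helpers can also be exercised standalone, outside the generation loop. A self-contained sketch (the tokenizer path, prompt, and vocab size are illustrative; only the call pattern mirrors the commit):

```python
import torch
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.tokenenforcer import TokenEnforcer

# Illustrative values; adjust for the actual model/tokenizer in use.
tokenizer = Tokenizer(model_path="/path/to/tokenizer.model")
vocab_size = 128256

tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
enforcer = TokenEnforcer(tokenizer_data, JsonSchemaParser(AnswerFormat.schema()))
processor = LogitsProcessor(enforcer)

# Given the token sequence so far and raw next-token logits of shape
# [batch, seq, vocab], the processor adds -inf to every disallowed entry.
prompt_tokens = tokenizer.encode("Give me a JSON object.", bos=True, eos=False)
logits = torch.zeros(1, 1, vocab_size)
constrained = processor.process_logits(prompt_tokens, logits)
```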

View file

@@ -9,36 +9,36 @@ from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
META_REFERENCE_DEPS = [
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
"lm-format-enforcer",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
],
pip_packages=META_REFERENCE_DEPS,
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference-quantized",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"fbgemm-gpu==0.8.0",
"torch",
"torchvision",
"transformers",
"zmq",
],
pip_packages=(
META_REFERENCE_DEPS
+ [
"fbgemm-gpu==0.8.0",
]
),
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
),
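
Factoring the shared list into `META_REFERENCE_DEPS` means the quantized provider only adds the quantization kernel on top of the base set. After the refactor its dependency list evaluates to:

```python
# Resulting quantized-provider pip_packages after the refactor.
assert META_REFERENCE_DEPS + ["fbgemm-gpu==0.8.0"] == [
    "accelerate",
    "blobfile",
    "fairscale",
    "torch",
    "torchvision",
    "transformers",
    "zmq",
    "lm-format-enforcer",
    "fbgemm-gpu==0.8.0",
]
```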

View file

@@ -85,11 +85,24 @@ async def inference_settings(request):
}
from pydantic import BaseModel
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
@pytest.fixture
def sample_messages():
question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
question_with_schema = f"{question}{AnswerFormat.schema_json()}"
return [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What's the weather like today?"),
# UserMessage(content="What's the weather like today?"),
UserMessage(content=question_with_schema),
]
@@ -177,6 +190,7 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
**inference_settings["common_params"],
)
print(response)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
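
Since the fixture now asks for JSON matching `AnswerFormat`, a natural follow-up assertion, not yet in this draft, would parse the reply and validate it against the schema once `response_format` is actually passed in the request:

```python
# Hedged sketch of a possible follow-up check (not part of this commit).
import json

answer = AnswerFormat(**json.loads(response.completion_message.content))
assert isinstance(answer.year_of_birth, int)
assert answer.first_name and answer.last_name
```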