guided decoding initial draft

Ashwin Bharambe 2024-10-21 18:44:19 -07:00
parent 1d241bf3fe
commit 6d26bbdce3
4 changed files with 133 additions and 22 deletions

View file

@@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None
 
 
+class JsonResponseFormat(BaseModel):
+    type: Literal["json"] = "json"
+    schema: Dict[str, Any]
+
+
+class GrammarResponseFormat(BaseModel):
+    type: Literal["grammar"] = "grammar"
+    bnf: Dict[str, Any]
+
+
+ResponseFormat = Annotated[
+    Union[JsonResponseFormat, GrammarResponseFormat],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedTextMedia
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
     model: str
     content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     logprobs: Optional[LogProbConfig] = None
@@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -188,6 +207,7 @@ class Inference(Protocol):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@@ -204,6 +224,7 @@ class Inference(Protocol):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
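
To make the new surface concrete, here is a minimal usage sketch (not part of the commit). It assumes the new types are importable from llama_stack.apis.inference; the Person model and the model identifier are purely illustrative.

from pydantic import BaseModel

from llama_stack.apis.inference import (  # assumed import path for the new types
    CompletionRequest,
    GrammarResponseFormat,
    JsonResponseFormat,
)


class Person(BaseModel):
    # Hypothetical schema used only for this example.
    name: str
    year_of_birth: int


# Constrain completion output to JSON matching Person's schema ...
json_request = CompletionRequest(
    model="Llama3.1-8B-Instruct",  # illustrative model id
    content="Give me Michael Jordan's details as JSON.",
    response_format=JsonResponseFormat(schema=Person.schema()),
)

# ... or to a BNF grammar; the discriminated union keeps the wire format explicit.
grammar_request = CompletionRequest(
    model="Llama3.1-8B-Instruct",
    content="Answer yes or no: did Michael Jordan play baseball?",
    response_format=GrammarResponseFormat(bnf={"root": '"yes" | "no"'}),
)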

View file

@@ -12,7 +12,7 @@ import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
 
+import math
+
+from lmformatenforcer import JsonSchemaParser
+from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -172,9 +185,16 @@ class Llama:
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
     ) -> Generator:
+        parser = JsonSchemaParser(AnswerFormat.schema())
+        tokenizer_data = build_token_enforcer_tokenizer_data(
+            self.tokenizer, self.args.vocab_size
+        )
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
+        logits_processor = LogitsProcessor(token_enforcer)
+
         params = self.model.params
 
-        if print_input_tokens:
+        if print_input_tokens or True:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -246,6 +266,11 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
+            # print(f"{logits=}")
+            input_ids = tokens[0, :cur_pos].tolist()
+            # logits = logits_processor.process_logits(input_ids, logits)
+            # print(f"{logits=}")
+
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = input_ids
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        # print(f"{allowed_tokens=}")
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
+
+
+def build_token_enforcer_tokenizer_data(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+) -> TokenEnforcerTokenizerData:
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
+    return TokenEnforcerTokenizerData(
+        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
+    )
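
The enforcement step itself is plain additive masking: every token id the TokenEnforcer rules out gets -inf added to its logit, so softmax assigns it zero probability and sampling can never pick it. A self-contained illustration of exactly that arithmetic (a toy vocabulary and a hard-coded allowed-token list stand in for the real enforcer):

import math

import torch

# Stand-in values: in generate(), scores comes from model.forward() and
# allowed_tokens from TokenEnforcer.get_allowed_tokens(token_sequence).
vocab_size = 8
scores = torch.randn(1, 1, vocab_size)      # [batch, seq, vocab]
allowed_tokens = torch.tensor([2, 5, 7])    # ids the schema permits next

mask = torch.full_like(scores, -math.inf)   # forbid everything ...
mask[:, :, allowed_tokens] = 0              # ... except the allowed ids
constrained = scores + mask

probs = torch.softmax(constrained[:, -1] / 0.7, dim=-1)  # temperature sampling as in the loop
assert torch.allclose(probs[0, [0, 1, 3, 4, 6]], torch.zeros(5))  # disallowed ids get zero mass

Note that in the generation loop above the commit still has the logits_processor.process_logits call commented out, and the enforcer is hard-wired to the example AnswerFormat schema rather than driven by the request's response_format, which is presumably the follow-up work this draft sets up.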

View file

@@ -9,36 +9,36 @@ from typing import List
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+]
+
 
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=META_REFERENCE_DEPS,
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
         ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference-quantized",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "fbgemm-gpu==0.8.0",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=(
+                META_REFERENCE_DEPS
+                + [
+                    "fbgemm-gpu==0.8.0",
+                ]
+            ),
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
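
A quick sanity check of how the shared list composes under this refactor (illustrative only, not part of the commit):

META_REFERENCE_DEPS = [
    "accelerate", "blobfile", "fairscale", "torch",
    "torchvision", "transformers", "zmq", "lm-format-enforcer",
]

# The quantized provider now extends the shared list instead of repeating it,
# so the new guided-decoding dependency reaches both providers automatically.
quantized_deps = META_REFERENCE_DEPS + ["fbgemm-gpu==0.8.0"]
assert "lm-format-enforcer" in quantized_deps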

View file

@@ -85,11 +85,24 @@ async def inference_settings(request):
     }
 
 
+from pydantic import BaseModel
+
+
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 @pytest.fixture
 def sample_messages():
+    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
+    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
-        UserMessage(content="What's the weather like today?"),
+        # UserMessage(content="What's the weather like today?"),
+        UserMessage(content=question_with_schema),
     ]
@@ -177,6 +190,7 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
         **inference_settings["common_params"],
     )
 
+    print(response)
     assert isinstance(response, ChatCompletionResponse)
     assert response.completion_message.role == "assistant"
     assert isinstance(response.completion_message.content, str)
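
Once guided decoding is enforced end to end, a natural follow-up for this test (not in this draft) would be to parse the reply back into the schema. A sketch, assuming the completion content arrives as a bare JSON string:

import json

# Hypothetical extra assertions for test_chat_completion_non_streaming once the
# response is guaranteed to conform to AnswerFormat's JSON schema.
answer = AnswerFormat(**json.loads(response.completion_message.content))
assert answer.last_name == "Jordan"          # what the prompt asks about
assert isinstance(answer.year_of_birth, int)
assert isinstance(answer.num_seasons_in_nba, int)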