From 6d26bbdce35d448c49369e1bc258545a1c6e246d Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 21 Oct 2024 18:44:19 -0700
Subject: [PATCH] guided decoding initial draft

---
 llama_stack/apis/inference/inference.py        | 21 +++++
 .../meta_reference/inference/generation.py     | 80 ++++++++++++++++++-
 llama_stack/providers/registry/inference.py    | 38 ++++-----
 .../tests/inference/test_inference.py          | 16 +++-
 4 files changed, 133 insertions(+), 22 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 5895e528e..8dc547b2d 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None
 
 
+class JsonResponseFormat(BaseModel):
+    type: Literal["json"] = "json"
+    schema: Dict[str, Any]
+
+
+class GrammarResponseFormat(BaseModel):
+    type: Literal["grammar"] = "grammar"
+    bnf: Dict[str, Any]
+
+
+ResponseFormat = Annotated[
+    Union[JsonResponseFormat, GrammarResponseFormat],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedTextMedia
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
 
@@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
     model: str
     content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     logprobs: Optional[LogProbConfig] = None
 
 
@@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
+    response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
 
@@ -188,6 +207,7 @@ class Inference(Protocol):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@@ -204,6 +224,7 @@
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
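For reference, a minimal caller-side sketch of how the new response_format field is intended to be used. This is not part of the patch: the ask_constrained helper and the inference_impl handle are hypothetical, and the import path assumes the usual llama_stack.apis.inference re-exports; only the response_format plumbing itself comes from the API change above.

from pydantic import BaseModel

from llama_stack.apis.inference import (  # assumed re-export path
    JsonResponseFormat,
    SystemMessage,
    UserMessage,
)


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


async def ask_constrained(inference_impl, model: str) -> str:
    # inference_impl: any provider implementing the Inference protocol above.
    response = await inference_impl.chat_completion(
        model=model,
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="Please give me information about Michael Jordan."),
        ],
        # New in this patch: ask the provider to constrain decoding to this JSON schema.
        response_format=JsonResponseFormat(schema=AnswerFormat.schema()),
        stream=False,
    )
    return response.completion_message.content
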
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 9ca128176..cdf4ec79d 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -12,7 +12,7 @@ import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
+import math
+
+from lmformatenforcer import JsonSchemaParser
+
+from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
@@ -172,9 +185,16 @@
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
     ) -> Generator:
+        parser = JsonSchemaParser(AnswerFormat.schema())
+        tokenizer_data = build_token_enforcer_tokenizer_data(
+            self.tokenizer, self.args.vocab_size
+        )
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
+        logits_processor = LogitsProcessor(token_enforcer)
+
         params = self.model.params
 
-        if print_input_tokens:
+        if print_input_tokens or True:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
@@ -246,6 +266,11 @@
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
+            # print(f"{logits=}")
+            input_ids = tokens[0, :cur_pos].tolist()
+            # logits = logits_processor.process_logits(input_ids, logits)
+            # print(f"{logits=}")
+
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = input_ids
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        # print(f"{allowed_tokens=}")
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
+
+
+def build_token_enforcer_tokenizer_data(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+) -> TokenEnforcerTokenizerData:
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
+    return TokenEnforcerTokenizerData(
+        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
+    )
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 5a09b6af5..6f8bc2c6e 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -9,36 +9,36 @@ from typing import List
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+]
+
+
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=META_REFERENCE_DEPS,
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
         ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference-quantized",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "fbgemm-gpu==0.8.0",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=(
+                META_REFERENCE_DEPS
+                + [
+                    "fbgemm-gpu==0.8.0",
+                ]
+            ),
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index afec9a837..86e37e39c 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -85,11 +85,24 @@ async def inference_settings(request):
     }
 
 
+from pydantic import BaseModel
+
+
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 @pytest.fixture
 def sample_messages():
+    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
+    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
-        UserMessage(content="What's the weather like today?"),
+        # UserMessage(content="What's the weather like today?"),
+        UserMessage(content=question_with_schema),
     ]
 
 
@@ -177,6 +190,7 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
         **inference_settings["common_params"],
    )
 
+    print(response)
     assert isinstance(response, ChatCompletionResponse)
     assert response.completion_message.role == "assistant"
     assert isinstance(response.completion_message.content, str)
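For a concrete picture of what LogitsProcessor.process_logits does to the model's logits, here is a small self-contained sketch of the same masking step. The allowed-token list and the logits tensor are stand-ins; in the patch they come from TokenEnforcer.get_allowed_tokens and self.model.forward respectively, and top-p filtering is omitted for brevity.

import math

import torch

# Stand-in values; shapes match the (batch, seq, vocab) logits used in generate().
vocab_size = 8
scores = torch.randn(1, 1, vocab_size)      # model logits for the current step
allowed = torch.tensor([2, 5, 7])           # token ids the JSON schema permits next

mask = torch.full_like(scores, -math.inf)   # disallow everything by default
mask[:, :, allowed] = 0                     # re-enable only the permitted ids
constrained = scores + mask                 # -inf logits can never be sampled

# Temperature-scaled sampling over the last position, analogous to the generate() loop.
probs = torch.softmax(constrained[:, -1] / 0.7, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
assert next_token.item() in {2, 5, 7}

Once the commented-out process_logits call in the decode loop is enabled, this masking would run at every step, so only tokens consistent with the schema survive sampling.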