mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
guided decoding initial draft
This commit is contained in:
parent
1d241bf3fe
commit
6d26bbdce3
4 changed files with 133 additions and 22 deletions
|
@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
stop_reason: Optional[StopReason] = None
|
||||
|
||||
|
||||
class JsonResponseFormat(BaseModel):
|
||||
type: Literal["json"] = "json"
|
||||
schema: Dict[str, Any]
|
||||
|
||||
|
||||
class GrammarResponseFormat(BaseModel):
|
||||
type: Literal["grammar"] = "grammar"
|
||||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedTextMedia
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
|
|||
model: str
|
||||
content_batch: List[InterleavedTextMedia]
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
||||
|
@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
@ -188,6 +207,7 @@ class Inference(Protocol):
|
|||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||
|
@ -204,6 +224,7 @@ class Inference(Protocol):
|
|||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
||||
|
|
|
@ -12,7 +12,7 @@ import os
|
|||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional, Union
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -34,6 +34,12 @@ from pydantic import BaseModel
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
import math
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser
|
||||
|
||||
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
|
@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
year_of_birth: int
|
||||
num_seasons_in_nba: int
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model.descriptor()))
|
||||
|
||||
|
@ -172,9 +185,16 @@ class Llama:
|
|||
include_stop_token: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
) -> Generator:
|
||||
parser = JsonSchemaParser(AnswerFormat.schema())
|
||||
tokenizer_data = build_token_enforcer_tokenizer_data(
|
||||
self.tokenizer, self.args.vocab_size
|
||||
)
|
||||
token_enforcer = TokenEnforcer(tokenizer_data, parser)
|
||||
logits_processor = LogitsProcessor(token_enforcer)
|
||||
|
||||
params = self.model.params
|
||||
|
||||
if print_input_tokens:
|
||||
if print_input_tokens or True:
|
||||
input_tokens = [
|
||||
self.formatter.vision_token if t == 128256 else t
|
||||
for t in model_input.tokens
|
||||
|
@ -246,6 +266,11 @@ class Llama:
|
|||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
# print(f"{logits=}")
|
||||
input_ids = tokens[0, :cur_pos].tolist()
|
||||
# logits = logits_processor.process_logits(input_ids, logits)
|
||||
# print(f"{logits=}")
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
|
@ -371,3 +396,54 @@ def sample_top_p(probs, p):
|
|||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def process_logits(
|
||||
self, input_ids: List[int], scores: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
token_sequence = input_ids
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
# print(f"{allowed_tokens=}")
|
||||
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def _build_regular_tokens_list(
|
||||
tokenizer: Tokenizer, vocab_size: int
|
||||
) -> List[Tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def build_token_enforcer_tokenizer_data(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
) -> TokenEnforcerTokenizerData:
|
||||
regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
|
||||
return TokenEnforcerTokenizerData(
|
||||
regular_tokens, tokenizer.decode, tokenizer.stop_tokens
|
||||
)
|
||||
|
|
|
@ -9,36 +9,36 @@ from typing import List
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
"accelerate",
|
||||
"blobfile",
|
||||
"fairscale",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"lm-format-enforcer",
|
||||
]
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[
|
||||
"accelerate",
|
||||
"blobfile",
|
||||
"fairscale",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"zmq",
|
||||
],
|
||||
pip_packages=META_REFERENCE_DEPS,
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="meta-reference-quantized",
|
||||
pip_packages=[
|
||||
"accelerate",
|
||||
"blobfile",
|
||||
"fairscale",
|
||||
"fbgemm-gpu==0.8.0",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"zmq",
|
||||
],
|
||||
pip_packages=(
|
||||
META_REFERENCE_DEPS
|
||||
+ [
|
||||
"fbgemm-gpu==0.8.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
|
|
|
@ -85,11 +85,24 @@ async def inference_settings(request):
|
|||
}
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnswerFormat(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
year_of_birth: int
|
||||
num_seasons_in_nba: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
|
||||
question_with_schema = f"{question}{AnswerFormat.schema_json()}"
|
||||
return [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
# UserMessage(content="What's the weather like today?"),
|
||||
UserMessage(content=question_with_schema),
|
||||
]
|
||||
|
||||
|
||||
|
@ -177,6 +190,7 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
|
|||
**inference_settings["common_params"],
|
||||
)
|
||||
|
||||
print(response)
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue