mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 23:29:43 +00:00

guided decoding initial draft

This commit is contained in:
parent 1d241bf3fe
commit 6d26bbdce3

4 changed files with 133 additions and 22 deletions
@@ -74,11 +74,28 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None
 
 
+class JsonResponseFormat(BaseModel):
+    type: Literal["json"] = "json"
+    schema: Dict[str, Any]
+
+
+class GrammarResponseFormat(BaseModel):
+    type: Literal["grammar"] = "grammar"
+    bnf: Dict[str, Any]
+
+
+ResponseFormat = Annotated[
+    Union[JsonResponseFormat, GrammarResponseFormat],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedTextMedia
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
 
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
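Note: ResponseFormat is a discriminated union keyed on the literal "type" field, so a serialized request deserializes back into the right variant. A minimal sketch of building both variants and attaching one to a request; the Output model and the model id are illustrative placeholders, not part of this commit:

    from pydantic import BaseModel

    class Output(BaseModel):  # placeholder target schema, for illustration only
        name: str
        year: int

    json_fmt = JsonResponseFormat(schema=Output.schema())          # serializes with type="json"
    bnf_fmt = GrammarResponseFormat(bnf={"root": '"yes" | "no"'})   # serializes with type="grammar"

    request = CompletionRequest(
        model="Llama3.1-8B-Instruct",  # placeholder model id
        content="Describe Michael Jordan as JSON.",
        response_format=json_fmt,
    )
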
@@ -107,6 +124,7 @@ class BatchCompletionRequest(BaseModel):
     model: str
     content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
+    response_format: Optional[ResponseFormat] = None
     logprobs: Optional[LogProbConfig] = None
 
 
@@ -129,6 +147,7 @@ class ChatCompletionRequest(BaseModel):
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
+    response_format: Optional[ResponseFormat] = None
 
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None
@@ -188,6 +207,7 @@ class Inference(Protocol):
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@@ -204,6 +224,7 @@ class Inference(Protocol):
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
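Note: with the new parameter threaded through the Inference protocol, any provider can be asked for guided output. A rough usage sketch; the client object and model id are assumptions, and AnswerFormat is the draft test schema added later in this commit:

    async def ask_guided(client: Inference):
        # Request JSON constrained to the AnswerFormat schema; the provider is
        # expected to enforce it during decoding rather than by post-hoc parsing.
        return await client.chat_completion(
            model="Llama3.1-8B-Instruct",  # placeholder model id
            messages=[UserMessage(content="Tell me about Michael Jordan.")],
            response_format=JsonResponseFormat(schema=AnswerFormat.schema()),
            stream=False,
        )
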
@@ -12,7 +12,7 @@ import os
 import sys
 import time
 from pathlib import Path
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,12 @@ from pydantic import BaseModel
 from termcolor import cprint
 
 from llama_stack.apis.inference import *  # noqa: F403
+import math
+
+from lmformatenforcer import JsonSchemaParser
+
+from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
@@ -42,6 +48,13 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
 
 
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
@@ -172,9 +185,16 @@ class Llama:
         include_stop_token: bool = False,
         print_input_tokens: bool = False,
     ) -> Generator:
+        parser = JsonSchemaParser(AnswerFormat.schema())
+        tokenizer_data = build_token_enforcer_tokenizer_data(
+            self.tokenizer, self.args.vocab_size
+        )
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
+        logits_processor = LogitsProcessor(token_enforcer)
+
         params = self.model.params
 
-        if print_input_tokens:
+        if print_input_tokens or True:
             input_tokens = [
                 self.formatter.vision_token if t == 128256 else t
                 for t in model_input.tokens
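Note: the enforcer above is hard-wired to the AnswerFormat draft schema. Presumably a follow-up would derive the parser from the request's response_format instead; a hypothetical dispatch helper, not part of this commit, might look like:

    def parser_for_response_format(response_format: Optional[ResponseFormat]):
        # Hypothetical: only the JSON variant maps onto JsonSchemaParser.
        if response_format is None:
            return None  # unconstrained decoding
        if isinstance(response_format, JsonResponseFormat):
            return JsonSchemaParser(response_format.schema)
        # GrammarResponseFormat would need a BNF-capable parser, which
        # JsonSchemaParser does not cover.
        raise NotImplementedError(f"unsupported response format: {response_format.type}")
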
@@ -246,6 +266,11 @@ class Llama:
             else:
                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
 
+            # print(f"{logits=}")
+            input_ids = tokens[0, :cur_pos].tolist()
+            # logits = logits_processor.process_logits(input_ids, logits)
+            # print(f"{logits=}")
+
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
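Note: the commented-out process_logits call is where enforcement would bite. The mask is 0 for allowed token ids and -inf elsewhere, so softmax assigns zero probability outside the allowed set. A toy illustration with a 4-token vocabulary (values are made up for the example):

    import math
    import torch

    scores = torch.tensor([[[2.0, 1.0, 0.5, -1.0]]])    # shape (bsz=1, seq=1, vocab=4)
    mask = torch.full_like(scores, -math.inf)
    mask[:, :, torch.tensor([0, 2])] = 0                # pretend tokens 0 and 2 are allowed
    probs = torch.softmax((scores + mask)[:, -1], dim=-1)
    # probs is approximately [[0.8176, 0.0000, 0.1824, 0.0000]];
    # all probability mass lands on the allowed tokens.
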
@@ -371,3 +396,54 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+
+class LogitsProcessor:
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.mask: Optional[torch.Tensor] = None
+
+    def process_logits(
+        self, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        token_sequence = input_ids
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        # print(f"{allowed_tokens=}")
+        allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            self.mask = torch.full_like(scores, -math.inf)
+
+        self.mask[:, :, allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def _build_regular_tokens_list(
+    tokenizer: Tokenizer, vocab_size: int
+) -> List[Tuple[int, str, bool]]:
+    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
+    regular_tokens = []
+
+    special_token_ids = set(tokenizer.special_tokens.values())
+    for token_idx in range(vocab_size):
+        if token_idx in special_token_ids:
+            continue
+
+        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
+        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
+        decoded_regular = tokenizer.decode([token_idx])
+        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
+        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
+    return regular_tokens
+
+
+def build_token_enforcer_tokenizer_data(
+    tokenizer: Tokenizer,
+    vocab_size: int,
+) -> TokenEnforcerTokenizerData:
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
+    return TokenEnforcerTokenizerData(
+        regular_tokens, tokenizer.decode, tokenizer.stop_tokens
+    )
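Note: build_token_enforcer_tokenizer_data decodes every entry in the vocabulary, so it is worth building once per tokenizer rather than inside every generate call as the draft does. One way to memoize it; a sketch, not part of this commit:

    _TOKENIZER_DATA_CACHE: dict = {}

    def cached_token_enforcer_tokenizer_data(
        tokenizer: Tokenizer, vocab_size: int
    ) -> TokenEnforcerTokenizerData:
        # _build_regular_tokens_list walks the whole vocab; cache per tokenizer instance.
        key = (id(tokenizer), vocab_size)
        if key not in _TOKENIZER_DATA_CACHE:
            _TOKENIZER_DATA_CACHE[key] = build_token_enforcer_tokenizer_data(
                tokenizer, vocab_size
            )
        return _TOKENIZER_DATA_CACHE[key]
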
@@ -9,36 +9,36 @@ from typing import List
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+]
+
+
 def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=META_REFERENCE_DEPS,
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
         ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference-quantized",
-            pip_packages=[
-                "accelerate",
-                "blobfile",
-                "fairscale",
-                "fbgemm-gpu==0.8.0",
-                "torch",
-                "torchvision",
-                "transformers",
-                "zmq",
-            ],
+            pip_packages=(
+                META_REFERENCE_DEPS
+                + [
+                    "fbgemm-gpu==0.8.0",
+                ]
+            ),
             module="llama_stack.providers.impls.meta_reference.inference",
             config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
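Note: with the shared META_REFERENCE_DEPS list, both providers now pull in lm-format-enforcer, and the quantized provider's pip_packages evaluates to the base list plus the quantization extra:

    quantized_deps = META_REFERENCE_DEPS + ["fbgemm-gpu==0.8.0"]
    # ['accelerate', 'blobfile', 'fairscale', 'torch', 'torchvision',
    #  'transformers', 'zmq', 'lm-format-enforcer', 'fbgemm-gpu==0.8.0']
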
@@ -85,11 +85,24 @@ async def inference_settings(request):
     }
 
 
+from pydantic import BaseModel
+
+
+class AnswerFormat(BaseModel):
+    first_name: str
+    last_name: str
+    year_of_birth: int
+    num_seasons_in_nba: int
+
+
 @pytest.fixture
 def sample_messages():
+    question = "Please give me information about Michael Jordan. You MUST answer using the following json schema: "
+    question_with_schema = f"{question}{AnswerFormat.schema_json()}"
     return [
         SystemMessage(content="You are a helpful assistant."),
-        UserMessage(content="What's the weather like today?"),
+        # UserMessage(content="What's the weather like today?"),
+        UserMessage(content=question_with_schema),
     ]
 
 
@@ -177,6 +190,7 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
         **inference_settings["common_params"],
     )
 
+    print(response)
     assert isinstance(response, ChatCompletionResponse)
     assert response.completion_message.role == "assistant"
     assert isinstance(response.completion_message.content, str)
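Note: since the prompt embeds the AnswerFormat schema (and the enforcer should eventually guarantee it), the test could assert structure rather than just type. A possible stricter check, assuming the model answers correctly; not part of this commit:

    answer = AnswerFormat.parse_raw(response.completion_message.content)
    assert answer.first_name == "Michael"
    assert answer.last_name == "Jordan"
    assert answer.year_of_birth == 1963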