Add support for Structured Output / Guided decoding (#281)

Added support for structured output in the API and added a reference implementation for meta-reference.

A few notes:

* Two formats are specified in the API: Json schema and EBNF based grammar
* Implementation only supports Json for now
We use lm-format-enhancer to provide the implementation right now but may change this especially because BNF grammars aren't supported by that library.
Fireworks has support for structured output and Together has limited supported for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for llama models since it is an extremely important and highly sought-after need by the developers.
This commit is contained in:
Ashwin Bharambe 2024-10-22 12:53:34 -07:00 committed by GitHub
parent 4c3d33e6f4
commit c06718fbd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 257 additions and 25 deletions

View file

@ -53,6 +53,7 @@ class InferenceClient(Inference):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -63,6 +64,7 @@ class InferenceClient(Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)

View file

@ -74,11 +74,35 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: Optional[StopReason] = None
class ResponseFormatType(Enum):
json_schema = "json_schema"
grammar = "grammar"
class JsonResponseFormat(BaseModel):
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -107,6 +131,7 @@ class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@ -129,6 +154,7 @@ class ChatCompletionRequest(BaseModel):
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -188,6 +214,7 @@ class Inference(Protocol):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@ -204,6 +231,7 @@ class Inference(Protocol):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...

View file

@ -75,6 +75,7 @@ class InferenceRouter(Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@ -88,6 +89,7 @@ class InferenceRouter(Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -102,6 +104,7 @@ class InferenceRouter(Inference):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -110,6 +113,7 @@ class InferenceRouter(Inference):
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)

View file

@ -52,6 +52,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@ -288,6 +289,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -53,6 +53,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -63,6 +64,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

View file

@ -56,6 +56,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -69,6 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -79,6 +81,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -115,6 +118,20 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
options = get_sampling_options(request)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,

View file

@ -93,6 +93,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -160,6 +161,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

View file

@ -71,6 +71,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -84,6 +85,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -94,6 +96,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -148,6 +151,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self.max_tokens - input_tokens - 1,
)
options = get_sampling_options(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return dict(
prompt=prompt,
stream=request.stream,

View file

@ -59,6 +59,7 @@ class TogetherInferenceAdapter(
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -69,6 +70,7 @@ class TogetherInferenceAdapter(
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

View file

@ -80,6 +80,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@ -90,6 +91,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,

View file

@ -26,6 +26,7 @@ class MockInferenceAPI:
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,

View file

@ -8,11 +8,12 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import math
import os
import sys
import time
from pathlib import Path
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@ -34,6 +35,9 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
@ -67,7 +71,7 @@ class Llama:
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
]
],
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -171,6 +175,7 @@ class Llama:
echo: bool = False,
include_stop_token: bool = False,
print_input_tokens: bool = False,
logits_processor: Optional["LogitsProcessor"] = None,
) -> Generator:
params = self.model.params
@ -246,6 +251,9 @@ class Llama:
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
@ -317,7 +325,11 @@ class Llama:
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
def chat_completion(
@ -345,6 +357,11 @@ class Llama:
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
)
@ -371,3 +388,64 @@ def sample_top_p(probs, p):
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
def process_logits(
self, tokens: torch.Tensor, scores: torch.Tensor
) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
if self.mask is not None:
self.mask.fill_(-math.inf)
else:
self.mask = torch.full_like(scores, -math.inf)
self.mask[:, :, allowed_tokens] = 0
scores = scores + self.mask
return scores
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
if response_format.type != ResponseFormatType.json_schema.value:
raise ValueError(f"Unsupported response format type {response_format.type}")
parser = JsonSchemaParser(response_format.schema)
data = TokenEnforcerTokenizerData(
_build_regular_tokens_list(tokenizer, vocab_size),
tokenizer.decode,
tokenizer.stop_tokens,
)
token_enforcer = TokenEnforcer(data, parser)
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(
tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
special_token_ids = set(tokenizer.special_tokens.values())
for token_idx in range(vocab_size):
if token_idx in special_token_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
decoded_regular = tokenizer.decode([token_idx])
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
return regular_tokens

View file

@ -71,6 +71,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
@ -81,6 +82,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -186,6 +188,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
@ -203,6 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)

View file

@ -9,36 +9,36 @@ from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
META_REFERENCE_DEPS = [
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
"lm-format-enforcer",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
],
pip_packages=META_REFERENCE_DEPS,
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference-quantized",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"fbgemm-gpu==0.8.0",
"torch",
"torchvision",
"transformers",
"zmq",
],
pip_packages=(
META_REFERENCE_DEPS
+ [
"fbgemm-gpu==0.8.0",
]
),
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
),

View file

@ -10,6 +10,8 @@ import os
import pytest
import pytest_asyncio
from pydantic import BaseModel, ValidationError
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@ -183,6 +185,63 @@ async def test_chat_completion_non_streaming(inference_settings, sample_messages
assert len(response.completion_message.content) > 0
@pytest.mark.asyncio
async def test_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::fireworks",
"remote::tgi",
):
pytest.skip("Other inference providers don't support structured output yet")
class AnswerFormat(BaseModel):
first_name: str
last_name: str
year_of_birth: int
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
response_format=JsonResponseFormat(
schema=AnswerFormat.model_json_schema(),
),
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.content, str)
answer = AnswerFormat.parse_raw(response.completion_message.content)
assert answer.first_name == "Michael"
assert answer.last_name == "Jordan"
assert answer.year_of_birth == 1963
assert answer.num_seasons_in_nba == 15
response = await inference_impl.chat_completion(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,
**inference_settings["common_params"],
)
assert isinstance(response, ChatCompletionResponse)
assert isinstance(response.completion_message.content, str)
with pytest.raises(ValidationError):
AnswerFormat.parse_raw(response.completion_message.content)
@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Tuple
from llama_models.llama3.api.chat_format import ChatFormat
@ -70,11 +71,25 @@ def chat_completion_request_to_messages(
and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return augment_messages_for_tools_llama_3_1(request)
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family == ModelFamily.llama3_2:
return augment_messages_for_tools_llama_3_2(request)
messages = augment_messages_for_tools_llama_3_2(request)
else:
return request.messages
messages = request.messages
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
messages.append(
UserMessage(
content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
)
)
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
return messages
def augment_messages_for_tools_llama_3_1(