Add support for Structured Output / Guided decoding (#281)

Added support for structured output in the API and added a reference implementation for meta-reference.

A few notes:

* Two formats are specified in the API: JSON Schema and EBNF-based grammar
* The implementation only supports JSON Schema for now
We use lm-format-enforcer to provide the implementation right now, but this may change, especially because BNF grammars aren't supported by that library.
Fireworks has support for structured output and Together has limited support for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for Llama models, since it is an extremely important and highly sought-after capability for developers.
This commit is contained in:
Ashwin Bharambe 2024-10-22 12:53:34 -07:00 committed by GitHub
parent 4c3d33e6f4
commit c06718fbd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 257 additions and 25 deletions

View file

@ -74,11 +74,35 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: Optional[StopReason] = None
class ResponseFormatType(Enum):
    """Tags for the structured-output formats a request can ask for.

    Each member's value is the string used as the ``type`` discriminator
    of the corresponding response-format model.
    """

    # Output described by a JSON schema.
    json_schema = "json_schema"
    # Output described by a grammar.
    grammar = "grammar"
class JsonResponseFormat(BaseModel):
    """Structured-output format described by a JSON schema."""

    # Discriminator tag. Written as a string literal because ``Literal[...]``
    # only accepts literal expressions for static type checkers (PEP 586);
    # it must stay equal to ``ResponseFormatType.json_schema.value``, which
    # is still used as the default so the two cannot silently drift apart
    # without validation failing.
    type: Literal["json_schema"] = ResponseFormatType.json_schema.value
    # The schema itself, as a plain dictionary.
    # NOTE(review): pydantic's ``BaseModel`` also exposes an attribute named
    # ``schema``, so this field shadows it — confirm that is acceptable
    # before relying on ``BaseModel.schema`` for this model.
    schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
    """Structured-output format described by a grammar."""

    # Discriminator tag. Written as a string literal because ``Literal[...]``
    # only accepts literal expressions for static type checkers (PEP 586);
    # it must stay equal to ``ResponseFormatType.grammar.value``, which is
    # still used as the default.
    type: Literal["grammar"] = ResponseFormatType.grammar.value
    # The grammar definition, carried as a plain dictionary.
    bnf: Dict[str, Any]
# Tagged union of the supported response formats; the ``type`` field of each
# member model selects which one a payload is validated as.
ResponseFormat = Annotated[
    Union[JsonResponseFormat, GrammarResponseFormat],
    Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
    """Request body for a completion call."""

    # Identifier of the model to run.
    model: str
    # Prompt content (text and/or media, per ``InterleavedTextMedia``).
    content: InterleavedTextMedia
    # Sampling settings; falls back to a default-constructed SamplingParams.
    sampling_params: Optional[SamplingParams] = SamplingParams()
    # Optional structured-output constraint; ``None`` leaves output
    # unconstrained by a response format.
    response_format: Optional[ResponseFormat] = None

    # Streaming is off unless explicitly requested.
    stream: Optional[bool] = False
    # Optional log-probability configuration; disabled when ``None``.
    logprobs: Optional[LogProbConfig] = None
@ -107,6 +131,7 @@ class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@ -129,6 +154,7 @@ class ChatCompletionRequest(BaseModel):
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@ -188,6 +214,7 @@ class Inference(Protocol):
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
@ -204,6 +231,7 @@ class Inference(Protocol):
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...