forked from phoenix-oss/llama-stack-mirror
Add support for Structured Output / Guided decoding (#281)
Added support for structured output in the API and added a reference implementation for meta-reference. A few notes: * Two formats are specified in the API: Json schema and EBNF based grammar * Implementation only supports Json for now We use lm-format-enhancer to provide the implementation right now but may change this especially because BNF grammars aren't supported by that library. Fireworks has support for structured output and Together has limited supported for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for llama models since it is an extremely important and highly sought-after need by the developers.
This commit is contained in:
parent
4c3d33e6f4
commit
c06718fbd5
16 changed files with 257 additions and 25 deletions
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
@ -70,11 +71,25 @@ def chat_completion_request_to_messages(
|
|||
and is_multimodal(model.core_model_id)
|
||||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
return augment_messages_for_tools_llama_3_1(request)
|
||||
messages = augment_messages_for_tools_llama_3_1(request)
|
||||
elif model.model_family == ModelFamily.llama3_2:
|
||||
return augment_messages_for_tools_llama_3_2(request)
|
||||
messages = augment_messages_for_tools_llama_3_2(request)
|
||||
else:
|
||||
return request.messages
|
||||
messages = request.messages
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
messages.append(
|
||||
UserMessage(
|
||||
content=f"Please respond in JSON format with the schema: {json.dumps(fmt.schema)}"
|
||||
)
|
||||
)
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise NotImplementedError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def augment_messages_for_tools_llama_3_1(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue