add chat completion support for JsonSchemaResponseFormat request_format
This commit is contained in:

parent a6f47f1090
commit e6b82a44eb
3 changed files with 18 additions and 0 deletions
@@ -159,6 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
                 model=self.get_provider_model_id(model_id),
                 messages=messages,
                 sampling_params=sampling_params,
+                response_format=response_format,
                 tools=tools,
                 tool_choice=tool_choice,
                 tool_prompt_format=tool_prompt_format,
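
With response_format now forwarded into the converted request, a caller can ask the NVIDIA adapter for schema-constrained output. A minimal sketch of such a call; the client handle, model id, and schema here are illustrative assumptions, not values from this commit:

    from llama_stack.apis.inference import JsonSchemaResponseFormat, UserMessage

    async def ask_for_json(client):  # `client` is a stand-in for an Inference API handle
        # Any JSON Schema dict works here; this one is purely illustrative.
        person_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        }
        response = await client.chat_completion(
            model_id="Llama3.1-8B-Instruct",
            messages=[UserMessage(content="Describe Alice, age 30, as JSON.")],
            response_format=JsonSchemaResponseFormat(json_schema=person_schema),
        )
        print(response.completion_message.content)
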
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
     Message,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -149,12 +150,22 @@ def convert_chat_completion_request(
     # top_k -> nvext.top_k
     # max_tokens -> max_tokens
     # repetition_penalty -> nvext.repetition_penalty
+    # response_format -> GrammarResponseFormat TODO(mf)
+    # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
     # tools -> tools
     # tool_choice ("auto", "required") -> tool_choice
     # tool_prompt_format -> TBD
     # stream -> stream
     # logprobs -> logprobs
 
+    if request.response_format and not isinstance(
+        request.response_format, JsonSchemaResponseFormat
+    ):
+        raise ValueError(
+            f"Unsupported response format: {request.response_format}. "
+            "Only JsonSchemaResponseFormat is supported."
+        )
+
     nvext = {}
     payload: Dict[str, Any] = dict(
         model=request.model,
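
The new guard means any non-JSON-schema format (e.g. GrammarResponseFormat, which the TODO above leaves unimplemented) fails fast. A hedged sketch of that behavior; the request construction and the grammar field are assumptions for illustration:

    import pytest

    from llama_stack.apis.inference import (
        ChatCompletionRequest,
        GrammarResponseFormat,
        UserMessage,
    )

    # convert_chat_completion_request is the function being patched in this diff
    request = ChatCompletionRequest(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="hi")],
        # GrammarResponseFormat is not supported by this converter yet
        response_format=GrammarResponseFormat(bnf={"root": '"yes" | "no"'}),
    )
    with pytest.raises(ValueError, match="Only JsonSchemaResponseFormat is supported"):
        convert_chat_completion_request(request)
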
@@ -167,6 +178,11 @@ def convert_chat_completion_request(
         },
     )
 
+    if request.response_format:
+        # server bug - setting guided_json changes the behavior of response_format resulting in an error
+        # payload.update(response_format="json_object")
+        nvext.update(guided_json=request.response_format.json_schema)
+
     if request.tools:
         payload.update(
             tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
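
Because of the server bug noted in the comment, the converter sets only the NVIDIA extension and leaves the OpenAI-style response_format field off the payload. A sketch of the mapping in isolation; the model id and the way nvext is merged into the payload are assumptions about the surrounding function:

    from llama_stack.apis.inference import JsonSchemaResponseFormat

    fmt = JsonSchemaResponseFormat(json_schema={"type": "object"})

    nvext: dict = {}
    payload: dict = {"model": "meta/llama-3.1-8b-instruct", "messages": []}

    if fmt:
        # payload["response_format"] = "json_object"  # intended, skipped due to the server bug
        nvext["guided_json"] = fmt.json_schema

    payload["nvext"] = nvext  # assumption: nvext is attached to the payload downstream
    # -> {'model': ..., 'messages': [], 'nvext': {'guided_json': {'type': 'object'}}}
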
@@ -198,6 +198,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")
 
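
With remote::nvidia added to the allow-list, the structured-output test now runs against this provider. A hedged sketch of the round trip such a test exercises; the argument names and schema are illustrative, not copied from the test file:

    import json

    from llama_stack.apis.inference import JsonSchemaResponseFormat, UserMessage

    async def check_structured_output(inference_impl, model_id):
        response = await inference_impl.chat_completion(
            model_id=model_id,
            messages=[UserMessage(content="Who wrote Hamlet? Answer as JSON.")],
            response_format=JsonSchemaResponseFormat(
                json_schema={
                    "type": "object",
                    "properties": {"last_name": {"type": "string"}},
                    "required": ["last_name"],
                }
            ),
        )
        # guided_json should force the completion to parse as JSON matching the schema
        parsed = json.loads(response.completion_message.content)
        assert "last_name" in parsed
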