mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
add chat completion support for JsonSchemaResponseFormat request_format
This commit is contained in:
parent
a6f47f1090
commit
e6b82a44eb
3 changed files with 18 additions and 0 deletions
|
@ -159,6 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
model=self.get_provider_model_id(model_id),
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
|
|
|
@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
|
@ -149,12 +150,22 @@ def convert_chat_completion_request(
|
|||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> GrammarResponseFormat TODO(mf)
|
||||
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
|
||||
# tools -> tools
|
||||
# tool_choice ("auto", "required") -> tool_choice
|
||||
# tool_prompt_format -> TBD
|
||||
# stream -> stream
|
||||
# logprobs -> logprobs
|
||||
|
||||
if request.response_format and not isinstance(
|
||||
request.response_format, JsonSchemaResponseFormat
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported response format: {request.response_format}. "
|
||||
"Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
nvext = {}
|
||||
payload: Dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
|
@ -167,6 +178,11 @@ def convert_chat_completion_request(
|
|||
},
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# server bug - setting guided_json changes the behavior of response_format resulting in an error
|
||||
# payload.update(response_format="json_object")
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.tools:
|
||||
payload.update(
|
||||
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
|
|
|
@ -198,6 +198,7 @@ class TestInference:
|
|||
"remote::fireworks",
|
||||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::nvidia",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support structured output yet")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue