add chat completion support for JsonSchemaResponseFormat response_format

Matthew Farrellee 2024-11-22 06:29:55 -05:00
parent a6f47f1090
commit e6b82a44eb
3 changed files with 18 additions and 0 deletions

@@ -159,6 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
model=self.get_provider_model_id(model_id),
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
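
For context, this pass-through means a caller can request structured output by attaching a JsonSchemaResponseFormat to the chat completion call. A minimal sketch, assuming JsonSchemaResponseFormat takes a json_schema dict; the adapter handle, model id, and message are illustrative, not part of this diff:

    from llama_stack.apis.inference import JsonSchemaResponseFormat, UserMessage

    # Constrain the completion to a JSON object with a "name" string field.
    response_format = JsonSchemaResponseFormat(
        json_schema={
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
    )

    # The adapter now forwards response_format into request conversion:
    # response = await adapter.chat_completion(
    #     model_id="Llama3.1-8B-Instruct",  # illustrative model id
    #     messages=[UserMessage(content="Reply with a JSON object containing a name.")],
    #     response_format=response_format,
    # )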

@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
ToolCallDelta,
ToolCallParseStatus,
@@ -149,12 +150,22 @@ def convert_chat_completion_request(
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> GrammarResponseFormat TODO(mf)
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(
request.response_format, JsonSchemaResponseFormat
):
raise ValueError(
f"Unsupported response format: {request.response_format}. "
"Only JsonSchemaResponseFormat is supported."
)
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
@@ -167,6 +178,11 @@ def convert_chat_completion_request(
},
)
if request.response_format:
# server bug - setting guided_json changes the behavior of response_format resulting in an error
# payload.update(response_format="json_object")
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
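
Net effect of the workaround above: the JSON schema travels only in nvext.guided_json, and response_format is left unset to avoid the server bug. An illustrative shape of the converted payload (model, message, and schema values are assumed, not taken from the diff):

    payload = {
        "model": "meta/llama-3.1-8b-instruct",  # illustrative provider model id
        "messages": [
            {"role": "user", "content": "Reply with a JSON object containing a name."}
        ],
        # "response_format": "json_object" is deliberately omitted; sending it
        # alongside guided_json triggers the server bug noted in the comment above
        "nvext": {
            "guided_json": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            }
        },
    }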

@@ -198,6 +198,7 @@ class TestInference:
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
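
Adding remote::nvidia to this list opts the provider into the suite's structured-output test, which roughly requests JSON constrained to a schema and validates the completion against it. A sketch under assumptions; the suite's actual prompt, schema, and fixture names may differ:

    from pydantic import BaseModel

    from llama_stack.apis.inference import JsonSchemaResponseFormat


    class AnswerFormat(BaseModel):
        name: str


    # response = await inference_impl.chat_completion(
    #     model_id=model_id,
    #     messages=[...],
    #     response_format=JsonSchemaResponseFormat(
    #         json_schema=AnswerFormat.model_json_schema()
    #     ),
    # )
    # AnswerFormat.model_validate_json(response.completion_message.content)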