From e6b82a44ebffdd49d68add339108307a707cb91e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 22 Nov 2024 06:29:55 -0500
Subject: [PATCH] add chat completion support for JsonSchemaResponseFormat
 request_format

---
 .../providers/remote/inference/nvidia/nvidia.py |  1 +
 .../remote/inference/nvidia/openai_utils.py     | 16 ++++++++++++++++
 .../tests/inference/test_text_inference.py      |  1 +
 3 files changed, 18 insertions(+)

diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 1422f7a29..f38aa7112 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -159,6 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             model=self.get_provider_model_id(model_id),
             messages=messages,
             sampling_params=sampling_params,
+            response_format=response_format,
             tools=tools,
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 595cf0c93..2dddeadf9 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
     Message,
     ToolCallDelta,
     ToolCallParseStatus,
@@ -149,12 +150,22 @@
     #  top_k -> nvext.top_k
     #  max_tokens -> max_tokens
     #  repetition_penalty -> nvext.repetition_penalty
+    #  response_format -> GrammarResponseFormat TODO(mf)
+    #  response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
     #  tools -> tools
     #  tool_choice ("auto", "required") -> tool_choice
     #  tool_prompt_format -> TBD
     #  stream -> stream
     #  logprobs -> logprobs

+    if request.response_format and not isinstance(
+        request.response_format, JsonSchemaResponseFormat
+    ):
+        raise ValueError(
+            f"Unsupported response format: {request.response_format}. "
+            "Only JsonSchemaResponseFormat is supported."
+        )
+
     nvext = {}
     payload: Dict[str, Any] = dict(
         model=request.model,
@@ -167,6 +178,11 @@
         },
     )

+    if request.response_format:
+        # server bug - setting guided_json changes the behavior of response_format resulting in an error
+        #  payload.update(response_format="json_object")
+        nvext.update(guided_json=request.response_format.json_schema)
+
     if request.tools:
         payload.update(
             tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index b35d947ec..1193a64ed 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -198,6 +198,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")

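Note for reviewers: the conversion change boils down to one extra field on the
outgoing OpenAI-style request body. A minimal sketch of the resulting payload,
assuming a hypothetical model id and schema (per the server-bug comment in the
patch, only the nvext extension is sent and response_format="json_object" is
deliberately omitted):

    # hypothetical schema the caller wants the model to conform to
    request_schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    payload = {
        "model": "meta/llama-3.1-8b-instruct",  # hypothetical model id
        "messages": [{"role": "user", "content": "Answer in JSON."}],
        # response_format="json_object" intentionally omitted (server bug, see patch)
        "nvext": {"guided_json": request_schema},  # NVIDIA guided-decoding extension
    }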
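Caller-side view, a sketch assuming the usual llama_stack Inference API shapes
(the model id, message, and schema below are illustrative, not from the patch;
"inference" stands in for whatever client handle you use):

    from llama_stack.apis.inference import JsonSchemaResponseFormat, UserMessage

    answer_schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    # the adapter now forwards response_format, and convert_chat_completion_request
    # turns it into nvext["guided_json"] on the outgoing request
    response = await inference.chat_completion(
        model_id="Llama3.1-8B-Instruct",  # illustrative model id
        messages=[UserMessage(content="Answer in JSON.")],
        response_format=JsonSchemaResponseFormat(json_schema=answer_schema),
    )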