diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 0f4034478..57f3db802 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -100,6 +100,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            response_format=response_format,
         )
         if stream:
             return self._stream_chat_completion(request, self.client)
@@ -180,6 +181,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 self.formatter,
             )
 
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                input_dict["extra_body"] = {
+                    "guided_json": request.response_format.json_schema
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index aa2f0b413..b84761219 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -140,6 +140,7 @@ class TestInference:
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
+            "remote::vllm",
             "remote::cerebras",
         ):
             pytest.skip(
@@ -200,6 +201,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::vllm",
             "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")
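
For context, a minimal sketch (not part of the diff) of what the adapter's new `extra_body` plumbing amounts to on the wire: vLLM's OpenAI-compatible server accepts a `guided_json` field in the request's extra body and constrains decoding to the given JSON schema. The base URL, model name, and schema below are illustrative assumptions.

```python
# Sketch only: assumes a vLLM OpenAI-compatible server at an assumed URL/model.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

# Hypothetical schema standing in for request.response_format.json_schema.
answer_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "year": {"type": "integer"},
    },
    "required": ["name", "year"],
}

# The adapter forwards the schema as extra_body={"guided_json": ...};
# vLLM then restricts generation so the output conforms to the schema.
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Who composed the Moonlight Sonata, and in what year?"}],
    extra_body={"guided_json": answer_schema},
)
print(response.choices[0].message.content)
```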