diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index 85bbb34b2..f19181320 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -82,10 +82,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): model: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: @@ -96,6 +96,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): tools=tools or [], tool_choice=tool_choice, tool_prompt_format=tool_prompt_format, + response_format=response_format, stream=stream, logprobs=logprobs, ) @@ -150,6 +151,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): self.max_tokens - input_tokens - 1, ) options = get_sampling_options(request) + if fmt := request.response_format: + if fmt.type == ResponseFormatType.json_schema.value: + options["grammar"] = { + "type": "json", + "value": fmt.schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise ValueError("Grammar response format not supported yet") + else: + raise ValueError(f"Unexpected response format: {fmt.type}") + return dict( prompt=prompt, stream=request.stream, diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index c6355b2dd..e89f672b1 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -194,6 +194,7 @@ async def test_structured_output(inference_settings): if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::fireworks", + "remote::tgi", ): pytest.skip("Other inference providers don't support structured output yet")