From ffb561070d912d25173d1f034d1e91f965bf3364 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 22 Oct 2024 22:36:38 -0700
Subject: [PATCH] Support structured output for Together (#289)

---
 .../adapters/inference/together/together.py     | 17 +++++++++++++++--
 .../providers/tests/inference/test_inference.py |  1 +
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index f88e4c4c2..2f258e620 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -70,10 +70,10 @@ class TogetherInferenceAdapter(
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -96,6 +96,7 @@ class TogetherInferenceAdapter(
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -130,11 +131,23 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
+        options = get_sampling_options(request)
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **get_sampling_options(request),
+            **options,
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index e89f672b1..ad49448e2 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -195,6 +195,7 @@ async def test_structured_output(inference_settings):
         "meta-reference",
         "remote::fireworks",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip("Other inference providers don't support structured output yet")
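
Note for reviewers: a minimal standalone sketch of the mapping this patch adds to _get_params. A json_schema response format is translated into Together's "json_object" payload, while grammar formats are rejected. AnswerFormat and build_response_format are illustrative names for this sketch, not part of the patch or the llama_stack API.

from pydantic import BaseModel


class AnswerFormat(BaseModel):
    """Illustrative output schema a caller might request."""

    name: str
    year_born: str


def build_response_format(fmt_type: str, schema: dict) -> dict:
    """Mirrors the new branch in _get_params: only json_schema is mapped;
    grammar is explicitly unsupported for Together."""
    if fmt_type == "json_schema":
        # Together expects a "json_object" format carrying the JSON schema.
        return {"type": "json_object", "schema": schema}
    if fmt_type == "grammar":
        raise NotImplementedError("Grammar response format not supported yet")
    raise ValueError(f"Unknown response format {fmt_type}")


# Example: the dict merged into the Together request params alongside
# model, prompt, stream, and the sampling options.
print(build_response_format("json_schema", AnswerFormat.model_json_schema()))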