WIP support structured output for Together

This commit is contained in:
Ashwin Bharambe 2024-10-22 22:30:54 -07:00
parent 2e5e46d896
commit dbfb10973f
2 changed files with 17 additions and 2 deletions

View file

@ -70,10 +70,10 @@ class TogetherInferenceAdapter(
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -96,6 +96,7 @@ class TogetherInferenceAdapter(
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -110,6 +111,7 @@ class TogetherInferenceAdapter(
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
print(r)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
@ -130,11 +132,23 @@ class TogetherInferenceAdapter(
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["respose_format"] = {
"type": "json_object",
"schema": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
**options,
}
async def embeddings(

View file

@ -195,6 +195,7 @@ async def test_structured_output(inference_settings):
"meta-reference",
"remote::fireworks",
"remote::tgi",
"remote::together",
):
pytest.skip("Other inference providers don't support structured output yet")