add support for TGI

Ashwin Bharambe 2024-10-21 23:26:31 -07:00
parent fe20a69f24
commit 510269e4c5
2 changed files with 14 additions and 1 deletion


@@ -82,10 +82,10 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
-        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -96,6 +96,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
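
The two hunks above are plumbing: the new response_format parameter moves next to sampling_params in the public chat_completion signature and is forwarded unchanged onto the internal request object. A toy, self-contained sketch of that flow; the dataclasses below are stand-ins inferred from the fmt.type / fmt.schema accesses later in this commit, not llama-stack's actual models:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ResponseFormat:
    # Shape inferred from the fmt.type / fmt.schema reads below; the real
    # llama-stack model is richer than this stand-in.
    type: str
    schema: Dict[str, Any]


@dataclass
class ChatCompletionRequest:
    model: str
    messages: List[Dict[str, str]]
    response_format: Optional[ResponseFormat] = None
    stream: bool = False


def chat_completion(
    model: str,
    messages: List[Dict[str, str]],
    response_format: Optional[ResponseFormat] = None,
    stream: bool = False,
) -> ChatCompletionRequest:
    # Mirrors the change above: the public argument rides along onto the
    # internal request object without transformation.
    return ChatCompletionRequest(
        model=model,
        messages=messages,
        response_format=response_format,
        stream=stream,
    )


request = chat_completion(
    "Llama3.1-8B-Instruct",  # illustrative model ID, not from this commit
    [{"role": "user", "content": "Return a JSON object with a 'city' key."}],
    response_format=ResponseFormat(
        type="json_schema",
        schema={"type": "object", "properties": {"city": {"type": "string"}}},
    ),
)
print(request.response_format.type)  # -> json_schema
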
@@ -150,6 +151,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             self.max_tokens - input_tokens - 1,
         )
         options = get_sampling_options(request)
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["grammar"] = {
+                    "type": "json",
+                    "value": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise ValueError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unexpected response format: {fmt.type}")
+
         return dict(
             prompt=prompt,
             stream=request.stream,
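
The hunk above carries the substantive change: a json_schema response format is translated into TGI's "grammar" option, BNF-style grammar formats are rejected for now, and anything unrecognized raises. The same branching, runnable in isolation (SimpleNamespace stands in for the request's response_format object):

from types import SimpleNamespace
from typing import Any, Dict

# String constants standing in for the ResponseFormatType enum values.
JSON_SCHEMA = "json_schema"
GRAMMAR = "grammar"


def apply_response_format(options: Dict[str, Any], fmt: Any) -> Dict[str, Any]:
    # Same logic as _get_params above: a JSON schema becomes a TGI
    # "grammar" option; explicit grammars are not wired up yet.
    if fmt.type == JSON_SCHEMA:
        options["grammar"] = {"type": "json", "value": fmt.schema}
    elif fmt.type == GRAMMAR:
        raise ValueError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unexpected response format: {fmt.type}")
    return options


fmt = SimpleNamespace(type="json_schema", schema={"type": "object"})
print(apply_response_format({"temperature": 0.7}, fmt))
# -> {'temperature': 0.7, 'grammar': {'type': 'json', 'value': {'type': 'object'}}}

Raising on unknown format types, rather than silently ignoring them, keeps an unsupported structured-output request from degrading into unconstrained generation.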

@@ -194,6 +194,7 @@ async def test_structured_output(inference_settings):
     if provider.__provider_spec__.provider_type not in (
         "meta-reference",
         "remote::fireworks",
+        "remote::tgi",
     ):
         pytest.skip("Other inference providers don't support structured output yet")
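
For reference, the options["grammar"] value built in _get_params matches the shape of TGI's guided-generation parameter, which is why the test gate above can now exercise real constrained decoding against TGI. A hedged sketch of an equivalent raw request against TGI's /generate endpoint (the URL and schema are illustrative, and this assumes a TGI version recent enough to support grammars):

import requests

TGI_URL = "http://localhost:8080/generate"  # illustrative endpoint

payload = {
    "inputs": "Extract the city: 'I flew into Paris on Monday.'",
    "parameters": {
        "max_new_tokens": 64,
        # Same shape as options["grammar"] in the adapter hunk above.
        "grammar": {
            "type": "json",
            "value": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
}

resp = requests.post(TGI_URL, json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])  # output is constrained to the schema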