completion() for tgi

Dinesh Yeduguru 2024-10-23 12:06:25 -07:00
parent 21f2e9adf5
commit 5570a63248
4 changed files with 100 additions and 8 deletions


@@ -24,9 +24,13 @@ from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
    completion_request_to_prompt,
    completion_request_to_prompt_model_input_info,
)

from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@@ -75,7 +79,88 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()
        request = CompletionRequest(
            model=model,
            content=content,
            sampling_params=sampling_params,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
        prompt, input_tokens = completion_request_to_prompt_model_input_info(
            request, self.formatter
        )
        max_new_tokens = min(
            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
            self.max_tokens - input_tokens - 1,
        )
        options = get_sampling_options(request)
        # delete key "max_tokens" from options since it's not supported by the API
        options.pop("max_tokens", None)

        if fmt := request.response_format:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["grammar"] = {
                    "type": "json",
                    "value": fmt.schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                raise ValueError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unexpected response format: {fmt.type}")

        return dict(
            prompt=prompt,
            stream=request.stream,
            details=True,
            max_new_tokens=max_new_tokens,
            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
            **options,
        )

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = self._get_params_for_completion(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token
                finish_reason = None
                if chunk.details:
                    finish_reason = chunk.details.finish_reason

                choice = OpenAICompatCompletionChoice(
                    text=token_result.text, finish_reason=finish_reason
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = self._get_params_for_completion(request)
        r = await self.client.text_generation(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r.details.finish_reason,
            text="".join(t.text for t in r.details.tokens),
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_completion_response(response, self.formatter)

    async def chat_completion(
        self,
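
For reference, a minimal sketch of how the new completion() path might be exercised. The adapter construction is elided and the model identifier is a placeholder; only the call shape (stream vs. non-stream) follows the diff above, everything else is an assumption for illustration.

import asyncio


async def demo(adapter):  # assumption: `adapter` is an initialized _HfAdapter
    # Non-streaming: completion() awaits _nonstream_completion() and returns one response.
    response = await adapter.completion(
        model="Llama3.1-8B",  # placeholder model id
        content="The capital of France is",
        stream=False,
    )
    print(response)

    # Streaming: completion() returns the async generator from _stream_completion(),
    # which yields CompletionResponseStreamChunk objects carrying a `delta` field.
    async for chunk in await adapter.completion(
        model="Llama3.1-8B",
        content="Write a haiku about GPUs:",
        stream=True,
    ):
        print(chunk.delta, end="")


# asyncio.run(demo(adapter))  # once an adapter has been constructed and initialized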


@@ -137,6 +137,7 @@ async def test_completion(inference_settings):
    if provider.__provider_spec__.provider_type not in (
        "meta-reference",
        "remote::ollama",
        "remote::tgi",
    ):
        pytest.skip("Other inference providers don't support completion() yet")


@@ -95,13 +95,6 @@ async def process_completion_stream_response(
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        if finish_reason:
            if finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)
        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
@@ -115,6 +108,12 @@ async def process_completion_stream_response(
            delta=text,
            stop_reason=stop_reason,
        )
        if finish_reason:
            if finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

    yield CompletionResponseStreamChunk(
        delta="",


@@ -31,6 +31,13 @@ def completion_request_to_prompt(
    return formatter.tokenizer.decode(model_input.tokens)


def completion_request_to_prompt_model_input_info(
    request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
    model_input = formatter.encode_content(request.content)
    return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))


def chat_completion_request_to_prompt(
    request: ChatCompletionRequest, formatter: ChatFormat
) -> str:
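
The token count returned by the new helper is what lets _get_params_for_completion() in the TGI adapter clamp max_new_tokens to the remaining context budget. A worked example with illustrative numbers only:

# Suppose the endpoint advertises a 4096-token window (self.max_tokens in the
# adapter), the prompt encodes to 100 tokens (the second element returned by
# completion_request_to_prompt_model_input_info), and the caller asked for 10_000.
max_tokens = 4096
input_tokens = 100
requested = 10_000  # request.sampling_params.max_tokens

max_new_tokens = min(
    requested or (max_tokens - input_tokens),
    max_tokens - input_tokens - 1,
)
assert max_new_tokens == 3995  # clamped so prompt + generation fit in the window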