From 5570a632482981e87ad74b97775e9577dba691b8 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Wed, 23 Oct 2024 12:06:25 -0700
Subject: [PATCH] completion() for tgi

---
 .../providers/adapters/inference/tgi/tgi.py   | 87 ++++++++++++++++++-
 .../tests/inference/test_inference.py         |  1 +
 .../utils/inference/openai_compat.py          | 13 ++-
 .../utils/inference/prompt_adapter.py         |  7 ++
 4 files changed, 100 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index f19181320..72b5dd1e6 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -24,9 +24,13 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_model_input_info,
+    completion_request_to_prompt,
+    completion_request_to_prompt_model_input_info,
 )
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@@ -75,7 +79,88 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+            request, self.formatter
+        )
+        max_new_tokens = min(
+            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
+            self.max_tokens - input_tokens - 1,
+        )
+
+        options = get_sampling_options(request)
+        # delete key "max_tokens" from options since it's not supported by the API
+        options.pop("max_tokens", None)
+
+        if fmt := request.response_format:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["grammar"] = {
+                    "type": "json",
+                    "value": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise ValueError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unexpected response format: {fmt.type}")
+
+        return dict(
+            prompt=prompt,
+            stream=request.stream,
+            details=True,
+            max_new_tokens=max_new_tokens,
+            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
+            **options,
+        )
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.text_generation(**params)
+            async for chunk in s:
+
+                token_result = chunk.token
+                finish_reason = None
+                if chunk.details:
+                    finish_reason = chunk.details.finish_reason
+
+                choice = OpenAICompatCompletionChoice(
+                    text=token_result.text, finish_reason=finish_reason
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+        r = await self.client.text_generation(**params)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r.details.finish_reason,
+            text="".join(t.text for t in r.details.tokens),
+        )
+
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+
+        return process_completion_response(response, self.formatter)
 
     async def chat_completion(
         self,
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index ad49448e2..8a1aadd33 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -137,6 +137,7 @@ async def test_completion(inference_settings):
     if provider.__provider_spec__.provider_type not in (
         "meta-reference",
         "remote::ollama",
+        "remote::tgi",
    ):
         pytest.skip("Other inference providers don't support completion() yet")
 
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 22ae8a717..f810dc40a 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -95,13 +95,6 @@ async def process_completion_stream_response(
         choice = chunk.choices[0]
         finish_reason = choice.finish_reason
 
-        if finish_reason:
-            if finish_reason in ["stop", "eos", "eos_token"]:
-                stop_reason = StopReason.end_of_turn
-            elif finish_reason == "length":
-                stop_reason = StopReason.out_of_tokens
-            break
-
         text = text_from_choice(choice)
         if text == "<|eot_id|>":
             stop_reason = StopReason.end_of_turn
@@ -115,6 +108,12 @@
             delta=text,
             stop_reason=stop_reason,
         )
+        if finish_reason:
+            if finish_reason in ["stop", "eos", "eos_token"]:
+                stop_reason = StopReason.end_of_turn
+            elif finish_reason == "length":
+                stop_reason = StopReason.out_of_tokens
+            break
 
     yield CompletionResponseStreamChunk(
         delta="",
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 48f1df02f..d204ab728 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -31,6 +31,13 @@ def completion_request_to_prompt(
     return formatter.tokenizer.decode(model_input.tokens)
 
 
+def completion_request_to_prompt_model_input_info(
+    request: CompletionRequest, formatter: ChatFormat
+) -> Tuple[str, int]:
+    model_input = formatter.encode_content(request.content)
+    return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
+
+
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
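
Note (not part of the patch): the new completion() path is a thin wrapper over
huggingface_hub's AsyncInferenceClient.text_generation(). The sketch below only
illustrates the call shape that _nonstream_completion and _stream_completion build
on; the endpoint URL, prompt, and token budget are placeholder assumptions, not
values taken from this diff.

    # Minimal sketch, assuming a TGI server at http://localhost:8080 (placeholder).
    import asyncio

    from huggingface_hub import AsyncInferenceClient


    async def main() -> None:
        client = AsyncInferenceClient(model="http://localhost:8080")

        # Non-streaming, mirroring _nonstream_completion: details=True makes the
        # response carry per-token info plus a finish_reason.
        r = await client.text_generation(
            "The capital of France is",
            details=True,
            max_new_tokens=16,
            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
        )
        print(r.details.finish_reason, "".join(t.text for t in r.details.tokens))

        # Streaming, mirroring _stream_completion: each chunk exposes .token, and
        # only the final chunk carries .details with the finish_reason.
        async for chunk in await client.text_generation(
            "The capital of France is",
            details=True,
            stream=True,
            max_new_tokens=16,
            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
        ):
            finish_reason = chunk.details.finish_reason if chunk.details else None
            print(chunk.token.text, finish_reason)


    if __name__ == "__main__":
        asyncio.run(main())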