mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 23:29:43 +00:00
completion() for tgi
This commit is contained in:
parent
21f2e9adf5
commit
5570a63248
4 changed files with 100 additions and 8 deletions
|
@ -24,9 +24,13 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
OpenAICompatCompletionResponse,
|
OpenAICompatCompletionResponse,
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
process_chat_completion_stream_response,
|
process_chat_completion_stream_response,
|
||||||
|
process_completion_response,
|
||||||
|
process_completion_stream_response,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
chat_completion_request_to_model_input_info,
|
chat_completion_request_to_model_input_info,
|
||||||
|
completion_request_to_prompt,
|
||||||
|
completion_request_to_prompt_model_input_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||||
|
@ -75,7 +79,88 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
raise NotImplementedError()
|
request = CompletionRequest(
|
||||||
|
model=model,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
|
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||||
|
prompt, input_tokens = completion_request_to_prompt_model_input_info(
|
||||||
|
request, self.formatter
|
||||||
|
)
|
||||||
|
max_new_tokens = min(
|
||||||
|
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||||
|
self.max_tokens - input_tokens - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
options = get_sampling_options(request)
|
||||||
|
# delete key "max_tokens" from options since its not supported by the API
|
||||||
|
options.pop("max_tokens", None)
|
||||||
|
|
||||||
|
if fmt := request.response_format:
|
||||||
|
if fmt.type == ResponseFormatType.json_schema.value:
|
||||||
|
options["grammar"] = {
|
||||||
|
"type": "json",
|
||||||
|
"value": fmt.schema,
|
||||||
|
}
|
||||||
|
elif fmt.type == ResponseFormatType.grammar.value:
|
||||||
|
raise ValueError("Grammar response format not supported yet")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected response format: {fmt.type}")
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
prompt=prompt,
|
||||||
|
stream=request.stream,
|
||||||
|
details=True,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||||
|
**options,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = self._get_params_for_completion(request)
|
||||||
|
|
||||||
|
async def _generate_and_convert_to_openai_compat():
|
||||||
|
s = await self.client.text_generation(**params)
|
||||||
|
async for chunk in s:
|
||||||
|
|
||||||
|
token_result = chunk.token
|
||||||
|
finish_reason = None
|
||||||
|
if chunk.details:
|
||||||
|
finish_reason = chunk.details.finish_reason
|
||||||
|
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
text=token_result.text, finish_reason=finish_reason
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _generate_and_convert_to_openai_compat()
|
||||||
|
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = self._get_params_for_completion(request)
|
||||||
|
r = await self.client.text_generation(**params)
|
||||||
|
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=r.details.finish_reason,
|
||||||
|
text="".join(t.text for t in r.details.tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
return process_completion_response(response, self.formatter)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -137,6 +137,7 @@ async def test_completion(inference_settings):
|
||||||
if provider.__provider_spec__.provider_type not in (
|
if provider.__provider_spec__.provider_type not in (
|
||||||
"meta-reference",
|
"meta-reference",
|
||||||
"remote::ollama",
|
"remote::ollama",
|
||||||
|
"remote::tgi",
|
||||||
):
|
):
|
||||||
pytest.skip("Other inference providers don't support completion() yet")
|
pytest.skip("Other inference providers don't support completion() yet")
|
||||||
|
|
||||||
|
|
|
@ -95,13 +95,6 @@ async def process_completion_stream_response(
|
||||||
choice = chunk.choices[0]
|
choice = chunk.choices[0]
|
||||||
finish_reason = choice.finish_reason
|
finish_reason = choice.finish_reason
|
||||||
|
|
||||||
if finish_reason:
|
|
||||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
|
||||||
stop_reason = StopReason.end_of_turn
|
|
||||||
elif finish_reason == "length":
|
|
||||||
stop_reason = StopReason.out_of_tokens
|
|
||||||
break
|
|
||||||
|
|
||||||
text = text_from_choice(choice)
|
text = text_from_choice(choice)
|
||||||
if text == "<|eot_id|>":
|
if text == "<|eot_id|>":
|
||||||
stop_reason = StopReason.end_of_turn
|
stop_reason = StopReason.end_of_turn
|
||||||
|
@ -115,6 +108,12 @@ async def process_completion_stream_response(
|
||||||
delta=text,
|
delta=text,
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
|
if finish_reason:
|
||||||
|
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||||
|
stop_reason = StopReason.end_of_turn
|
||||||
|
elif finish_reason == "length":
|
||||||
|
stop_reason = StopReason.out_of_tokens
|
||||||
|
break
|
||||||
|
|
||||||
yield CompletionResponseStreamChunk(
|
yield CompletionResponseStreamChunk(
|
||||||
delta="",
|
delta="",
|
||||||
|
|
|
@ -31,6 +31,13 @@ def completion_request_to_prompt(
|
||||||
return formatter.tokenizer.decode(model_input.tokens)
|
return formatter.tokenizer.decode(model_input.tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def completion_request_to_prompt_model_input_info(
|
||||||
|
request: CompletionRequest, formatter: ChatFormat
|
||||||
|
) -> Tuple[str, int]:
|
||||||
|
model_input = formatter.encode_content(request.content)
|
||||||
|
return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
|
||||||
|
|
||||||
|
|
||||||
def chat_completion_request_to_prompt(
|
def chat_completion_request_to_prompt(
|
||||||
request: ChatCompletionRequest, formatter: ChatFormat
|
request: ChatCompletionRequest, formatter: ChatFormat
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue