mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
completion() for tgi (#295)
This commit is contained in:
parent
cb84034567
commit
3e1c3fdb3f
9 changed files with 173 additions and 35 deletions
|
@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
|
|||
choices: List[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_options(request: ChatCompletionRequest) -> dict:
|
||||
def get_sampling_options(params: SamplingParams) -> dict:
|
||||
options = {}
|
||||
if params := request.sampling_params:
|
||||
if params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
options[attr] = getattr(params, attr)
|
||||
|
@ -64,7 +64,18 @@ def process_completion_response(
|
|||
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
||||
) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||
if choice.text.endswith("<|eot_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
content=choice.text[: -len("<|eot_id|>")],
|
||||
)
|
||||
# drop suffix <eom_id> if present and return stop reason as end of message
|
||||
if choice.text.endswith("<|eom_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=choice.text[: -len("<|eom_id|>")],
|
||||
)
|
||||
return CompletionResponse(
|
||||
stop_reason=get_stop_reason(choice.finish_reason),
|
||||
content=choice.text,
|
||||
|
@ -95,13 +106,6 @@ async def process_completion_stream_response(
|
|||
choice = chunk.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if finish_reason:
|
||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = text_from_choice(choice)
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
|
@ -115,6 +119,12 @@ async def process_completion_stream_response(
|
|||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
if finish_reason:
|
||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue