Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
add completion() for ollama (#280)
This commit is contained in:
parent: e2a5a2e10d
commit: 1d241bf3fe

5 changed files with 138 additions and 15 deletions
The first hunk extends get_sampling_options so that a max_tokens setting is also passed through as Ollama's num_predict option:

```diff
@@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
     if params := request.sampling_params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
+                if attr == "max_tokens":
+                    options["num_predict"] = getattr(params, attr)
                 options[attr] = getattr(params, attr)
 
         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
```
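For context, here is a minimal, self-contained sketch of how that mapping behaves. SamplingParams and to_ollama_options are stand-ins invented for illustration, not the real llama-stack types, and the repeat_penalty line is an assumption since the hunk above is truncated right after the repetition_penalty check:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    # Stand-in for the real llama-stack SamplingParams model.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None


def to_ollama_options(params: SamplingParams) -> dict:
    """Mirror of the loop above, using a tuple instead of a set so the
    iteration order (and the printed result) is deterministic."""
    options = {}
    for attr in ("temperature", "top_p", "top_k", "max_tokens"):
        if getattr(params, attr):
            if attr == "max_tokens":
                # Ollama's name for the generation length cap.
                options["num_predict"] = getattr(params, attr)
            options[attr] = getattr(params, attr)
    if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
        # Assumed mapping: the hunk is cut off before this branch's body.
        options["repeat_penalty"] = params.repetition_penalty
    return options


print(to_ollama_options(SamplingParams(temperature=0.7, max_tokens=128)))
# {'temperature': 0.7, 'num_predict': 128, 'max_tokens': 128}
```

Note that max_tokens is still copied under its own key as well, since the original options[attr] assignment is unchanged context in the hunk; presumably this is harmless for backends that ignore options they do not recognize.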
The second hunk factors the finish_reason-to-StopReason mapping out of process_chat_completion_response into a new get_stop_reason helper, and adds process_completion_response for the plain (non-chat) completion path:

```diff
@@ -49,25 +51,35 @@ def text_from_choice(choice) -> str:
     return choice.text
 
 
+def get_stop_reason(finish_reason: str) -> StopReason:
+    if finish_reason in ["stop", "eos"]:
+        return StopReason.end_of_turn
+    elif finish_reason == "eom":
+        return StopReason.end_of_message
+    elif finish_reason == "length":
+        return StopReason.out_of_tokens
+
+    return StopReason.out_of_tokens
+
+
+def process_completion_response(
+    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+) -> CompletionResponse:
+    choice = response.choices[0]
+
+    return CompletionResponse(
+        stop_reason=get_stop_reason(choice.finish_reason),
+        content=choice.text,
+    )
+
+
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
-    stop_reason = None
-    if reason := choice.finish_reason:
-        if reason in ["stop", "eos"]:
-            stop_reason = StopReason.end_of_turn
-        elif reason == "eom":
-            stop_reason = StopReason.end_of_message
-        elif reason == "length":
-            stop_reason = StopReason.out_of_tokens
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
     completion_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), stop_reason
+        text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
     return ChatCompletionResponse(
         completion_message=completion_message,
```
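As a quick sanity check of the new mapping, here is a self-contained version; StopReason is a stand-in enum for the one the real module imports from llama_models, but the body of get_stop_reason is verbatim from the hunk:

```python
from enum import Enum


class StopReason(Enum):
    # Stand-in mirroring the llama_models StopReason values used above.
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


def get_stop_reason(finish_reason: str) -> StopReason:
    if finish_reason in ["stop", "eos"]:
        return StopReason.end_of_turn
    elif finish_reason == "eom":
        return StopReason.end_of_message
    elif finish_reason == "length":
        return StopReason.out_of_tokens

    # Unknown reasons fall through to out_of_tokens as a conservative default.
    return StopReason.out_of_tokens


assert get_stop_reason("stop") is StopReason.end_of_turn
assert get_stop_reason("eom") is StopReason.end_of_message
assert get_stop_reason("length") is StopReason.out_of_tokens
assert get_stop_reason("anything-else") is StopReason.out_of_tokens
```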
The third hunk adds the streaming counterpart, process_completion_stream_response, which turns OpenAI-compatible stream chunks into CompletionResponseStreamChunk deltas, swallows the <|eot_id|> and <|eom_id|> special tokens while recording their stop reason, and emits a final empty chunk carrying the stop reason:

```diff
@@ -75,6 +87,43 @@ def process_chat_completion_response(
     )
 
 
+async def process_completion_stream_response(
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+) -> AsyncGenerator:
+
+    stop_reason = None
+
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        finish_reason = choice.finish_reason
+
+        if finish_reason:
+            if finish_reason in ["stop", "eos", "eos_token"]:
+                stop_reason = StopReason.end_of_turn
+            elif finish_reason == "length":
+                stop_reason = StopReason.out_of_tokens
+            break
+
+        text = text_from_choice(choice)
+        if text == "<|eot_id|>":
+            stop_reason = StopReason.end_of_turn
+            text = ""
+            continue
+        elif text == "<|eom_id|>":
+            stop_reason = StopReason.end_of_message
+            text = ""
+            continue
+        yield CompletionResponseStreamChunk(
+            delta=text,
+            stop_reason=stop_reason,
+        )
+
+    yield CompletionResponseStreamChunk(
+        delta="",
+        stop_reason=stop_reason,
+    )
+
+
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
```
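To illustrate how a caller might drive the new streaming helper, here is a runnable sketch with stand-in chunk types so it works without llama-stack installed. consume() and fake_stream() are hypothetical names; consume() mirrors the control flow of process_completion_stream_response, with plain strings standing in for StopReason:

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional


@dataclass
class Choice:
    # Stand-in for an OpenAI-compatible stream choice.
    text: str
    finish_reason: Optional[str] = None


@dataclass
class Chunk:
    choices: List[Choice]


async def fake_stream() -> AsyncGenerator[Chunk, None]:
    # Two text deltas, an end-of-turn special token, then a final "stop" chunk.
    for choice in [
        Choice(text="Hello"),
        Choice(text=", world"),
        Choice(text="<|eot_id|>"),
        Choice(text="", finish_reason="stop"),
    ]:
        yield Chunk(choices=[choice])


async def consume(stream: AsyncGenerator[Chunk, None]) -> None:
    # Mirrors the control flow of process_completion_stream_response above.
    stop_reason = None
    async for chunk in stream:
        choice = chunk.choices[0]
        if choice.finish_reason:
            if choice.finish_reason in ["stop", "eos", "eos_token"]:
                stop_reason = "end_of_turn"
            elif choice.finish_reason == "length":
                stop_reason = "out_of_tokens"
            break
        text = choice.text
        if text == "<|eot_id|>":
            stop_reason = "end_of_turn"  # special tokens are swallowed, not yielded
            continue
        elif text == "<|eom_id|>":
            stop_reason = "end_of_message"
            continue
        print(f"delta={text!r} stop_reason={stop_reason}")
    # The trailing empty chunk carries the final stop reason, as in the helper.
    print(f"delta='' stop_reason={stop_reason}")


asyncio.run(consume(fake_stream()))
# delta='Hello' stop_reason=None
# delta=', world' stop_reason=None
# delta='' stop_reason=end_of_turn
```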