add completion() for ollama (#280)

This commit is contained in:
Dinesh Yeduguru 2024-10-21 22:26:33 -07:00 committed by GitHub
parent e2a5a2e10d
commit 1d241bf3fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 138 additions and 15 deletions

View file

@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
if params := request.sampling_params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr):
if attr == "max_tokens":
options["num_predict"] = getattr(params, attr)
options[attr] = getattr(params, attr)
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
@ -49,25 +51,35 @@ def text_from_choice(choice) -> str:
return choice.text
def get_stop_reason(finish_reason: str) -> StopReason:
if finish_reason in ["stop", "eos"]:
return StopReason.end_of_turn
elif finish_reason == "eom":
return StopReason.end_of_message
elif finish_reason == "length":
return StopReason.out_of_tokens
return StopReason.out_of_tokens
def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
choice = response.choices[0]
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
)
def process_chat_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
choice = response.choices[0]
stop_reason = None
if reason := choice.finish_reason:
if reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif reason == "eom":
stop_reason = StopReason.end_of_message
elif reason == "length":
stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
completion_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), stop_reason
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
return ChatCompletionResponse(
completion_message=completion_message,
@ -75,6 +87,43 @@ def process_chat_completion_response(
)
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice)
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
)
yield CompletionResponseStreamChunk(
delta="",
stop_reason=stop_reason,
)
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator: