From c72ce9e726d20b7736b44ef1702bc1191baee743 Mon Sep 17 00:00:00 2001 From: Hassan El Mghari Date: Mon, 26 Aug 2024 21:24:00 -0400 Subject: [PATCH] accounting for eos --- llama_toolchain/inference/together/together.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/llama_toolchain/inference/together/together.py b/llama_toolchain/inference/together/together.py index adbbb7ecf..e7ccf623e 100644 --- a/llama_toolchain/inference/together/together.py +++ b/llama_toolchain/inference/together/together.py @@ -115,7 +115,10 @@ class TogetherInference(Inference): ) stop_reason = None if r.choices[0].finish_reason: - if r.choices[0].finish_reason == "stop": + if ( + r.choices[0].finish_reason == "stop" + or r.choices[0].finish_reason == "eos" + ): stop_reason = StopReason.end_of_turn elif r.choices[0].finish_reason == "length": stop_reason = StopReason.out_of_tokens @@ -147,7 +150,11 @@ class TogetherInference(Inference): **options, ): if chunk.choices[0].finish_reason: - if stop_reason is None and chunk.choices[0].finish_reason == "stop": + if ( + stop_reason is None and chunk.choices[0].finish_reason == "stop" + ) or ( + stop_reason is None and chunk.choices[0].finish_reason == "eos" + ): stop_reason = StopReason.end_of_turn elif ( stop_reason is None