From a8711dc5c28f1e90f250e8753d4296c00ebd6d43 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 20 Sep 2023 20:53:32 -0700
Subject: [PATCH] fix hf tgi best of bug

---
 litellm/llms/huggingface_restapi.py | 23 +++++++++++------------
 pyproject.toml                      |  2 +-
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index cf49635e8..ba54b4eb9 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -173,6 +173,16 @@ def completion(
                         "content"
                     ] = completion_response["generated_text"]
             elif task == "text-generation-inference":
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS + FINISH REASON
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
                 if "best_of" in optional_params and optional_params["best_of"] > 1:
                     if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
                         choices_list = []
@@ -183,18 +193,7 @@ def completion(
                             message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
                             choice_obj = Choices(finish_reason=item["finish_reason"], index=idx, message=message_obj)
                             choices_list.append(choice_obj)
-                        model_response["choices"] = choices_list
-                else:
-                    model_response["choices"][0]["message"][
-                        "content"
-                    ] = completion_response[0]["generated_text"]
-                    ## GETTING LOGPROBS + FINISH REASON
-                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                        model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
-                        sum_logprob = 0
-                        for token in completion_response[0]["details"]["tokens"]:
-                            sum_logprob += token["logprob"]
-                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+                        model_response["choices"].extend(choices_list)
             else:
                 model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
             ## CALCULATING USAGE
diff --git a/pyproject.toml b/pyproject.toml
index 8bcdc880a..5b537bfe8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.719"
+version = "0.1.720"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
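
For illustration only (not part of the patch), a minimal usage sketch of the behavior this change targets, assuming litellm forwards best_of through to the provider's optional_params as the patched code expects; the model name and api_base below are placeholders. After the fix, the primary TGI generation stays in choices[0] (with its finish_reason and summed token logprobs), and any best_of_sequences from the TGI details payload are appended as additional choices instead of overwriting the choices list.

    import litellm

    # Placeholder model and endpoint; any TGI-served deployment applies.
    response = litellm.completion(
        model="huggingface/bigcode/starcoder",
        messages=[{"role": "user", "content": "Say hi"}],
        api_base="https://my-tgi-endpoint.example.com",  # hypothetical TGI deployment
        best_of=2,  # asks TGI for an extra candidate sequence
    )

    # choices[0] holds the main generation; choices[1:] are the appended
    # best_of_sequences produced by the fixed code path.
    for choice in response.choices:
        print(choice.index, choice.finish_reason, choice.message.content)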