diff --git a/litellm/main.py b/litellm/main.py
index ab8548bc7b..3732d89f94 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1821,7 +1821,20 @@ def text_completion(*args, **kwargs):
     if "engine" in kwargs:
         kwargs["model"] = kwargs["engine"]
         kwargs.pop("engine")
+
+    # if echo == True, for TGI llms we need to set top_n_tokens to 3
+    if kwargs.get("echo", False) == True and (kwargs["model"] not in litellm.open_ai_text_completion_models):
+        # for tgi llms
+        if "top_n_tokens" not in kwargs:
+            kwargs["top_n_tokens"] = 3
     if "prompt" in kwargs:
+        if type(kwargs["prompt"]) == list:
+            new_prompt = ""
+            for chunk in kwargs["prompt"]:
+                decoded = litellm.utils.decode(model="gpt2", tokens=chunk)
+                new_prompt += decoded
+            kwargs["prompt"] = new_prompt
+
         messages = [{"role": "system", "content": kwargs["prompt"]}]
         kwargs["messages"] = messages
         kwargs.pop("prompt")
@@ -1834,19 +1847,7 @@ def text_completion(*args, **kwargs):
     # only supported for TGI models
     try:
         raw_response = response._hidden_params.get("original_response", None)
-        tokens = []
-        token_logprobs = []
-        if "prefill" in raw_response[0]["details"]:
-            prefills = raw_response[0]["details"]['prefill']
-            for prefill in prefills:
-                tokens.append(prefill['text'])
-                token_logprobs.append(prefill['logprob'])
-        new_tokens = [token['text'] for token in raw_response[0]['details']['tokens']]
-        new_token_logprobs = [token['logprob'] for token in raw_response[0]['details']['tokens']]
-        transformed_logprobs = {
-            "tokens": tokens + new_tokens,
-            "token_logprobs": token_logprobs + new_token_logprobs
-        }
+        transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
     except Exception as e:
         print("LiteLLM non blocking exception", e)
     formatted_response_obj = {
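
Note: the second hunk moves the inline TGI logprob handling into a `litellm.utils.transform_logprobs` helper. That helper's implementation is not part of this diff; the sketch below is a reconstruction based only on the removed inline block (merging the TGI `prefill` tokens/logprobs from an echoed prompt with the generated tokens/logprobs), so the real function in litellm/utils.py may differ in signature and error handling.

```python
# Sketch only: reconstructed from the inline logic removed by this diff.
# Assumes `raw_response` is the TGI "original_response" list, where
# raw_response[0]["details"] holds optional "prefill" entries (echoed
# prompt tokens) and the generated "tokens".
def transform_logprobs(raw_response):
    tokens = []
    token_logprobs = []
    details = raw_response[0]["details"]

    # prompt tokens echoed back by TGI when echo=True / top_n_tokens is set
    if "prefill" in details:
        for prefill in details["prefill"]:
            tokens.append(prefill["text"])
            token_logprobs.append(prefill["logprob"])

    # generated tokens and their logprobs
    new_tokens = [token["text"] for token in details["tokens"]]
    new_token_logprobs = [token["logprob"] for token in details["tokens"]]

    return {
        "tokens": tokens + new_tokens,
        "token_logprobs": token_logprobs + new_token_logprobs,
    }
```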