mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
(feat) text completion set top_n_tokens for tgi
This commit is contained in:
parent 19737f95c5
commit 8ca7af3a63

1 changed file with 14 additions and 13 deletions
@@ -1821,7 +1821,20 @@ def text_completion(*args, **kwargs):
     if "engine" in kwargs:
         kwargs["model"] = kwargs["engine"]
         kwargs.pop("engine")
+
+    # if echo == True, for TGI llms we need to set top_n_tokens to 3
+    if kwargs.get("echo", False) == True and (kwargs["model"] not in litellm.open_ai_text_completion_models):
+        # for tgi llms
+        if "top_n_tokens" not in kwargs:
+            kwargs["top_n_tokens"] = 3
     if "prompt" in kwargs:
+        if type(kwargs["prompt"]) == list:
+            new_prompt = ""
+            for chunk in kwargs["prompt"]:
+                decoded = litellm.utils.decode(model="gpt2", tokens=chunk)
+                new_prompt += decoded
+            kwargs["prompt"] = new_prompt
+
         messages = [{"role": "system", "content": kwargs["prompt"]}]
         kwargs["messages"] = messages
         kwargs.pop("prompt")
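This first hunk is the feature itself: when echo is requested for a model outside litellm.open_ai_text_completion_models (i.e. a TGI-served model), top_n_tokens is forced to 3 so the provider returns the prefill logprobs needed to echo the prompt, and list prompts are decoded back to a string with the gpt2 tokenizer. A minimal sketch of a call that would take this path; the Hugging Face model name is illustrative and not taken from this commit:

    import litellm

    # echo=True on a non-OpenAI text-completion model: after this commit,
    # text_completion() injects top_n_tokens=3 before dispatching to TGI.
    response = litellm.text_completion(
        model="huggingface/bigcode/starcoder",  # assumed TGI-served model, for illustration only
        prompt="def hello_world():",
        max_tokens=5,
        logprobs=1,
        echo=True,
    )
    # in this version text_completion returns a plain dict (formatted_response_obj)
    print(response["choices"][0]["text"])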
@@ -1834,19 +1847,7 @@ def text_completion(*args, **kwargs):
         # only supported for TGI models
         try:
             raw_response = response._hidden_params.get("original_response", None)
-            tokens = []
-            token_logprobs = []
-            if "prefill" in raw_response[0]["details"]:
-                prefills = raw_response[0]["details"]['prefill']
-                for prefill in prefills:
-                    tokens.append(prefill['text'])
-                    token_logprobs.append(prefill['logprob'])
-            new_tokens = [token['text'] for token in raw_response[0]['details']['tokens']]
-            new_token_logprobs = [token['logprob'] for token in raw_response[0]['details']['tokens']]
-            transformed_logprobs = {
-                "tokens": tokens + new_tokens,
-                "token_logprobs": token_logprobs + new_token_logprobs
-            }
+            transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
         except Exception as e:
             print("LiteLLM non blocking exception", e)
         formatted_response_obj = {
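The second hunk replaces the inline prefill handling with a shared helper, litellm.utils.transform_logprobs. Judging only from the code deleted here, the helper is expected to fold the TGI prefill (echoed prompt) tokens and the generated tokens into a single OpenAI-style logprobs payload. A standalone sketch under that assumption, with a hypothetical function name:

    # Sketch of the transformation this commit centralizes; the raw_response shape
    # (a list whose first element has a "details" dict with optional "prefill" and
    # "tokens" entries) is assumed from the deleted block above, not from TGI docs.
    def transform_tgi_logprobs(raw_response):
        tokens, token_logprobs = [], []
        details = raw_response[0]["details"]
        # prefill carries the echoed prompt tokens when echo/top_n_tokens was requested
        for prefill in details.get("prefill", []):
            tokens.append(prefill["text"])
            token_logprobs.append(prefill["logprob"])
        # generated tokens follow the prefill tokens
        tokens += [token["text"] for token in details["tokens"]]
        token_logprobs += [token["logprob"] for token in details["tokens"]]
        return {"tokens": tokens, "token_logprobs": token_logprobs}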