fix(huggingface_restapi.py): output parsing for chat template models

Author: Krrish Dholakia
Date:   2023-11-06 11:42:57 -08:00
Parent: c46c193c96
Commit: 65c01eae23
2 changed files with 61 additions and 3 deletions

@@ -78,6 +78,22 @@ def validate_environment(api_key, headers):
         headers = default_headers
     return headers
 
+def output_parser(generated_text: str):
+    """
+    Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+    Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+    """
+    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+    for token in chat_template_tokens:
+        if generated_text.strip().startswith(token):
+            generated_text = generated_text.replace(token, "", 1)
+        if generated_text.endswith(token):
+            generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+    return generated_text
+
 tgi_models_cache = None
 conv_models_cache = None
 def read_tgi_conv_models():
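For illustration, a quick check of the new helper with made-up strings (not part of the commit): a leading role token is dropped via a single forward replace, and the reversed-string replace trims exactly one trailing stop token while leaving any occurrence in the middle of the text untouched.

    # Hypothetical inputs, shown only to illustrate output_parser's behavior.
    print(output_parser("<|assistant|> Hello there!</s>"))  # " Hello there!"
    print(output_parser("a </s> b"))                        # "a </s> b" (mid-text token kept)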
@@ -308,7 +324,7 @@ def completion(
                 if len(completion_response[0]["generated_text"]) > 0:
                     model_response["choices"][0]["message"][
                         "content"
-                    ] = completion_response[0]["generated_text"]
+                    ] = output_parser(completion_response[0]["generated_text"])
                 ## GETTING LOGPROBS + FINISH REASON
                 if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
                     model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
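For context, this branch parses a text-generation-inference style payload; a minimal sketch of the shape it assumes (keys taken from the code above, values invented):

    # Assumed TGI response shape; only the keys read above are shown,
    # and all values are illustrative.
    completion_response = [
        {
            "generated_text": "<|assistant|> Hi!</s>",
            "details": {
                "finish_reason": "eos_token",
                "tokens": [
                    {"logprob": -0.12},
                    {"logprob": -0.50},
                ],
            },
        }
    ]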
@@ -324,7 +340,7 @@ def completion(
                         for token in item["tokens"]:
                             sum_logprob += token["logprob"]
                         if len(item["generated_text"]) > 0:
-                            message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
+                            message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
                         else:
                             message_obj = Message(content=None)
                         choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
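The sequence-level logprob attached to each extra choice is just the sum of the per-token logprobs; a minimal sketch of what that loop computes for one hypothetical best-of item:

    # Hypothetical best-of sequence; sum_logprob becomes the choice's score.
    item = {
        "generated_text": "<|assistant|>Hi</s>",
        "finish_reason": "length",
        "tokens": [{"logprob": -0.3}, {"logprob": -0.7}],
    }
    sum_logprob = 0
    for token in item["tokens"]:
        sum_logprob += token["logprob"]  # -1.0 total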
@@ -334,7 +350,7 @@ def completion(
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"][
                     "content"
-                ] = completion_response[0]["generated_text"]
+                ] = output_parser(completion_response[0]["generated_text"])
             ## CALCULATING USAGE
             prompt_tokens = 0
             try:
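The diff truncates inside the usage block, so the counting logic is not shown; as a hedged sketch only (tiktoken and the prompt string are assumptions for illustration, not confirmed from this commit), prompt-token counting typically looks like:

    import tiktoken  # assumption: illustrative tokenizer, not taken from the diff

    prompt = "<|user|>Hello<|assistant|>"  # hypothetical prompt string
    encoding = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(encoding.encode(prompt))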