forked from phoenix/litellm-mirror

fix(huggingface_restapi.py): output parsing for chat template models

parent c46c193c96
commit 65c01eae23

2 changed files with 61 additions and 3 deletions
@@ -78,6 +78,22 @@ def validate_environment(api_key, headers):
         headers = default_headers
     return headers
 
+def output_parser(generated_text: str):
+    """
+    Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+    Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+    """
+    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+    for token in chat_template_tokens:
+        if generated_text.strip().startswith(token):
+            generated_text = generated_text.replace(token, "", 1)
+        if generated_text.endswith(token):
+            generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+    return generated_text
+
+
+
 tgi_models_cache = None
 conv_models_cache = None
 def read_tgi_conv_models():
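For reference, a minimal usage sketch of the new helper, not part of the diff. The import path is an assumption based on the file name. Note the reversed-string trick above: reversing generated_text lets replace(token, "", 1) strip the last occurrence of a trailing token instead of the first.

    # Minimal sketch; the import path below is an assumption.
    from litellm.llms.huggingface_restapi import output_parser

    # Leading and trailing ChatML tokens are stripped, one occurrence each:
    print(output_parser("<|assistant|>Hello!</s>"))    # -> Hello!
    # Tokens embedded mid-text are left untouched:
    print(output_parser("escape the <s> token here"))  # -> escape the <s> token here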
@@ -308,7 +324,7 @@ def completion(
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"][
                     "content"
-                ] = completion_response[0]["generated_text"]
+                ] = output_parser(completion_response[0]["generated_text"])
             ## GETTING LOGPROBS + FINISH REASON
             if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
                 model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
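The response shape this hunk reads from, as a hedged sketch inferred only from the keys referenced above (a real TGI response carries more fields than these):

    # Assumed shape of completion_response, based solely on the keys used in the hunk.
    completion_response = [{
        "generated_text": "<|assistant|>Hi there</s>",
        "details": {
            "finish_reason": "eos_token",
            "tokens": [{"logprob": -0.5}, {"logprob": -0.25}],
        },
    }]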
@@ -324,7 +340,7 @@ def completion(
                     for token in item["tokens"]:
                         sum_logprob += token["logprob"]
                     if len(item["generated_text"]) > 0:
-                        message_obj = Message(content=item["generated_text"], logprobs=sum_logprob)
+                        message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
                     else:
                         message_obj = Message(content=None)
                     choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
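To make the aggregation concrete, a self-contained sketch of the loop above; the item dict mirrors only the fields this hunk touches, everything else is assumed:

    # Hypothetical best-of-n candidate, shaped after the fields referenced above.
    item = {
        "generated_text": "<|assistant|>Paris</s>",
        "finish_reason": "eos_token",
        "tokens": [{"logprob": -0.5}, {"logprob": -0.25}],
    }
    sum_logprob = 0.0
    for token in item["tokens"]:
        sum_logprob += token["logprob"]  # total log-probability of the candidate
    print(sum_logprob)  # -0.75
    # With this commit the content is parsed before building the Message:
    # message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)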
@@ -334,7 +350,7 @@ def completion(
         if len(completion_response[0]["generated_text"]) > 0:
             model_response["choices"][0]["message"][
                 "content"
-            ] = completion_response[0]["generated_text"]
+            ] = output_parser(completion_response[0]["generated_text"])
         ## CALCULATING USAGE
         prompt_tokens = 0
         try:
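Net effect of the three call-site changes, continuing the sketch from the first hunk: the content stored on the response is now the parsed text rather than the raw model output.

    # Illustrative raw output from a chat-template model (the value is made up):
    raw = "<|assistant|> The answer is 42.</s>"
    content = output_parser(raw)  # what now lands in choices[0].message.content
    print(repr(content))  # ' The answer is 42.' (leading token and trailing </s> removed)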