Handle Llama 2 EOS tokens in streaming

This commit is contained in:
Krrish Dholakia 2023-09-18 13:44:19 -07:00
parent d81c75a0b3
commit 633e36de42
8 changed files with 36 additions and 5 deletions

View file

@ -36,7 +36,7 @@ from .exceptions import (
)
from typing import cast, List, Dict, Union, Optional
from .caching import Cache
from .llms.prompt_templates.factory import llama_2_special_tokens
####### ENVIRONMENT VARIABLES ####################
dotenv.load_dotenv() # Loading env variables using dotenv
@ -2518,7 +2518,11 @@ class CustomStreamWrapper:
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
if "token" in data_json and "text" in data_json["token"]:
return data_json["token"]["text"]
text = data_json["token"]["text"]
if "meta-llama/Llama-2" in self.model: #clean eos tokens like </s> from the returned output text
if any(token in text for token in llama_2_special_tokens):
text = text.replace("<s>", "").replace("</s>", "")
return text
else:
return ""
return ""