fix(utils.py): ensure streaming output parsing only applied for hf / sagemaker models

selectively applies the <s> / </s> checking
Krrish Dholakia 2024-04-17 17:43:41 -07:00
parent 53df916f69
commit 7d0086d742
2 changed files with 23 additions and 0 deletions


@@ -220,6 +220,20 @@ tools_schema = [
# test_completion_cohere_stream()


def test_completion_azure_stream_special_char():
    messages = [
        {"role": "user", "content": "Respond with the '<' sign and nothing else."}
    ]
    response = completion(model="azure/chatgpt-v-2", messages=messages, stream=True)
    response_str = ""
    for part in response:
        response_str += part.choices[0].delta.content or ""
    print(f"response_str: {response_str}")
    assert len(response_str) > 0
    raise Exception("it worked")


def test_completion_cohere_stream_bad_key():
    try:
        litellm.cache = None


@@ -8856,7 +8856,16 @@ class CustomStreamWrapper:
            raise e

    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        """
        Output parse <s> / </s> special tokens for sagemaker + hf streaming.
        """
        hold = False
        if (
            self.custom_llm_provider != "huggingface"
            and self.custom_llm_provider != "sagemaker"
        ):
            return hold, chunk

        if finish_reason:
            for token in self.special_tokens:
                if token in chunk:
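
The effect of the new guard clause can be illustrated standalone. The following is a minimal sketch, not litellm's actual CustomStreamWrapper: the `special_tokens` list and the token-stripping step are assumed for illustration, since the diff above truncates after `if token in chunk:`. It shows that chunks from providers other than huggingface / sagemaker now pass through untouched, while hf / sagemaker chunks still get `<s>` / `</s>` parsed out.

from typing import Optional

class StreamWrapperSketch:
    # Standalone approximation of the guarded special-token parsing above.
    def __init__(self, custom_llm_provider: str):
        self.custom_llm_provider = custom_llm_provider
        self.special_tokens = ["<s>", "</s>"]  # assumed token list

    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        hold = False
        # New guard: only hf / sagemaker output goes through token parsing.
        if self.custom_llm_provider not in ("huggingface", "sagemaker"):
            return hold, chunk
        if finish_reason:
            for token in self.special_tokens:
                chunk = chunk.replace(token, "")  # simplified stripping
        return hold, chunk

# An Azure chunk containing a literal '<' is returned verbatim...
assert StreamWrapperSketch("azure").check_special_tokens("<", "stop") == (False, "<")
# ...while an hf end-of-sequence token is still stripped.
assert StreamWrapperSketch("huggingface").check_special_tokens("</s>", "stop") == (False, "")

This is also why the new test above expects the streamed Azure response to contain the '<' character: previously the special-token parsing could swallow it for every provider.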