forked from phoenix/litellm-mirror
fix(utils.py): ensure streaming output parsing is only applied for hf / sagemaker models

selectively applies the <s> / </s> checking to those providers
parent 53df916f69
commit 7d0086d742
2 changed files with 23 additions and 0 deletions
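For context on the failure mode this commit addresses: while streaming, the wrapper can withhold a chunk that might still be the prefix of a special token such as <s>, and for providers like Azure that means a response consisting only of "<" never gets emitted. A minimal sketch of that behavior (illustrative only; the special_tokens list and the prefix-hold check here are assumptions about the surrounding wrapper code, not part of this diff):

    # Illustrative sketch, not code from this commit.
    special_tokens = ["<s>", "</s>"]

    def naive_hold(chunk: str) -> bool:
        # Hold the chunk back if it could still grow into a special token.
        return any(token.startswith(chunk) for token in special_tokens)

    print(naive_hold("<"))      # True  -> a bare "<" from Azure is withheld
    print(naive_hold("hello"))  # False -> ordinary text streams through

With the guard added below, this hold is only attempted for huggingface / sagemaker models, so literal "<" output from other providers streams through untouched.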
@@ -220,6 +220,20 @@ tools_schema = [
 # test_completion_cohere_stream()
 
 
+def test_completion_azure_stream_special_char():
+    messages = [
+        {"role": "user", "content": "Respond with the '<' sign and nothing else."}
+    ]
+    response = completion(model="azure/chatgpt-v-2", messages=messages, stream=True)
+    response_str = ""
+    for part in response:
+        response_str += part.choices[0].delta.content or ""
+
+    print(f"response_str: {response_str}")
+    assert len(response_str) > 0
+    raise Exception("it worked")
+
+
 def test_completion_cohere_stream_bad_key():
     try:
         litellm.cache = None
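Tests in this file are typically also runnable by direct invocation, as the commented-out # test_completion_cohere_stream() call above suggests. A sketch of a manual run of the new test (assumes Azure credentials such as AZURE_API_KEY / AZURE_API_BASE are configured; those env var names are litellm's usual convention, not something this diff sets up):

    # Manual-run sketch, mirroring this file's call-at-bottom style. Note the
    # test currently ends with raise Exception("it worked"), so a passing run
    # surfaces that exception after the assertion succeeds.
    test_completion_azure_stream_special_char()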
@@ -8856,7 +8856,16 @@ class CustomStreamWrapper:
             raise e
 
     def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
+        """
+        Output parse <s> / </s> special tokens for sagemaker + hf streaming.
+        """
         hold = False
+        if (
+            self.custom_llm_provider != "huggingface"
+            and self.custom_llm_provider != "sagemaker"
+        ):
+            return hold, chunk
+
         if finish_reason:
             for token in self.special_tokens:
                 if token in chunk:
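For reference, a sketch of how check_special_tokens is presumably consumed inside the wrapper's chunk handling (an assumption about surrounding code that is not shown in this diff; completion_obj and model_response are hypothetical names): the early return above means the hold/strip logic never runs for providers other than huggingface and sagemaker.

    # Hypothetical caller inside CustomStreamWrapper (not shown in this diff).
    hold, content = self.check_special_tokens(
        chunk=completion_obj["content"],
        finish_reason=model_response.choices[0].finish_reason,
    )
    if hold:
        # Possibly a partial special token: buffer instead of emitting.
        return None
    completion_obj["content"] = content  # emit with any <s> / </s> stripped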