mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(utils.py): remove special characters from streaming output
This commit is contained in:
parent
a2f2fd3841
commit
6e7e409615
2 changed files with 60 additions and 34 deletions
|
@ -36,7 +36,7 @@ def test_completion_custom_provider_model_name():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
test_completion_custom_provider_model_name()
|
# test_completion_custom_provider_model_name()
|
||||||
|
|
||||||
|
|
||||||
def test_completion_claude():
|
def test_completion_claude():
|
||||||
|
@ -221,34 +221,32 @@ def test_get_hf_task_for_model():
|
||||||
# ################### Hugging Face TGI models ########################
|
# ################### Hugging Face TGI models ########################
|
||||||
# # TGI model
|
# # TGI model
|
||||||
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
|
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
|
||||||
# def hf_test_completion_tgi():
|
def hf_test_completion_tgi():
|
||||||
# litellm.set_verbose=True
|
# litellm.set_verbose=True
|
||||||
# try:
|
try:
|
||||||
# response = litellm.completion(
|
response = completion(
|
||||||
# model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
|
model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
|
||||||
# messages=[{ "content": "Hello, how are you?","role": "user"}],
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
# api_base="https://3kk3h56912qga4-80.proxy.runpod.net",
|
)
|
||||||
# )
|
# Add any assertions here to check the response
|
||||||
# # Add any assertions here to check the response
|
print(response)
|
||||||
# print(response)
|
except Exception as e:
|
||||||
# except Exception as e:
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
|
||||||
# hf_test_completion_tgi()
|
# hf_test_completion_tgi()
|
||||||
|
|
||||||
# def hf_test_completion_tgi_stream():
|
def hf_test_completion_tgi_stream():
|
||||||
# try:
|
try:
|
||||||
# response = litellm.completion(
|
response = completion(
|
||||||
# model="huggingface/glaiveai/glaive-coder-7b",
|
model = 'huggingface/HuggingFaceH4/zephyr-7b-beta',
|
||||||
# messages=[{ "content": "Hello, how are you?","role": "user"}],
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
# api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
|
stream=True
|
||||||
# stream=True
|
)
|
||||||
# )
|
# Add any assertions here to check the response
|
||||||
# # Add any assertions here to check the response
|
print(response)
|
||||||
# print(response)
|
for chunk in response:
|
||||||
# for chunk in response:
|
print(chunk["choices"][0]["delta"]["content"])
|
||||||
# print(chunk)
|
except Exception as e:
|
||||||
# except Exception as e:
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
|
||||||
# hf_test_completion_tgi_stream()
|
# hf_test_completion_tgi_stream()
|
||||||
|
|
||||||
# ################### Hugging Face Conversational models ########################
|
# ################### Hugging Face Conversational models ########################
|
||||||
|
|
|
@ -3686,6 +3686,8 @@ class CustomStreamWrapper:
|
||||||
self.completion_stream = completion_stream
|
self.completion_stream = completion_stream
|
||||||
self.sent_first_chunk = False
|
self.sent_first_chunk = False
|
||||||
self.sent_last_chunk = False
|
self.sent_last_chunk = False
|
||||||
|
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
|
||||||
|
self.holding_chunk = ""
|
||||||
if self.logging_obj:
|
if self.logging_obj:
|
||||||
# Log the type of the received item
|
# Log the type of the received item
|
||||||
self.logging_obj.post_call(str(type(completion_stream)))
|
self.logging_obj.post_call(str(type(completion_stream)))
|
||||||
|
@ -3700,6 +3702,29 @@ class CustomStreamWrapper:
|
||||||
if self.logging_obj:
|
if self.logging_obj:
|
||||||
self.logging_obj.post_call(text)
|
self.logging_obj.post_call(text)
|
||||||
|
|
||||||
|
def check_special_tokens(self, chunk: str):
    """Filter chat-template special tokens out of the start of a stream.

    Until the first chunk has been sent (``self.sent_first_chunk``), incoming
    text is appended to ``self.holding_chunk``, every entry of
    ``self.special_tokens`` is removed from the combined text, and leading
    whitespace is dropped. The text is held back only while it is empty or
    could still grow into a special token (i.e. it is a prefix of one).

    Args:
        chunk: The newly received piece of streamed text.

    Returns:
        A ``(hold, text)`` tuple. When ``hold`` is True the caller should not
        emit anything yet; ``text`` is what is currently buffered. When
        ``hold`` is False, ``text`` is the cleaned content to emit and the
        internal buffer has been reset.
    """
    if self.sent_first_chunk is True:
        # Filtering only applies before the first chunk has been emitted.
        return False, chunk

    buffered = self.holding_chunk + chunk

    # Remove tokens cumulatively so that several tokens in the same buffer
    # (e.g. "<s><|assistant|>") are all stripped in a single pass.
    for token in self.special_tokens:
        buffered = buffered.replace(token, "")

    # Only leading whitespace is dropped: trailing whitespace may separate
    # this chunk from the next one and must survive.
    buffered = buffered.lstrip()

    # Hold only while the text might still turn into a special token. A prefix
    # test (not a substring test) keeps ordinary words such as "user" or "s"
    # from being mistaken for a token fragment.
    hold = buffered == "" or any(
        token.startswith(buffered) for token in self.special_tokens
    )

    # Keep the held text so it is re-emitted if it turns out to be content.
    self.holding_chunk = buffered if hold else ""
    return hold, buffered
|
||||||
|
|
||||||
|
|
||||||
def handle_anthropic_chunk(self, chunk):
|
def handle_anthropic_chunk(self, chunk):
|
||||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||||
text = ""
|
text = ""
|
||||||
|
@ -4100,6 +4125,9 @@ class CustomStreamWrapper:
|
||||||
|
|
||||||
model_response.model = self.model
|
model_response.model = self.model
|
||||||
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
|
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
|
||||||
|
hold, model_response_str = self.check_special_tokens(completion_obj["content"])
|
||||||
|
if hold is False:
|
||||||
|
completion_obj["content"] = model_response_str
|
||||||
if self.sent_first_chunk == False:
|
if self.sent_first_chunk == False:
|
||||||
completion_obj["role"] = "assistant"
|
completion_obj["role"] = "assistant"
|
||||||
self.sent_first_chunk = True
|
self.sent_first_chunk = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue