diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index c8c423db2..d47a7486d 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -22,6 +22,28 @@ def llama_2_chat_pt(messages):
)
return prompt
+def mistral_instruct_pt(messages):
+ prompt = custom_prompt(
+ initial_prompt_value="",
+ role_dict={
+ "system": {
+ "pre_message": "[INST]",
+ "post_message": "[/INST]"
+ },
+ "user": {
+ "pre_message": "[INST]",
+ "post_message": "[/INST]"
+ },
+ "assistant": {
+ "pre_message": "[INST]",
+ "post_message": "[/INST]"
+ }
+ },
+ final_prompt_value="",
+ messages=messages
+ )
+ return prompt
+
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
prompt = ""
@@ -116,4 +138,6 @@ def prompt_factory(model: str, messages: list):
return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
return llama_2_chat_pt(messages=messages)
+ elif "mistralai/mistral" in model and "instruct" in model:
+ return mistral_instruct_pt(messages=messages)
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
\ No newline at end of file
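Note: a minimal standalone sketch of the string the new mistral_instruct_pt would render, assuming custom_prompt simply concatenates initial_prompt_value, then pre_message + content + post_message for each message, then final_prompt_value. render_mistral_instruct below is a hypothetical helper for illustration, not litellm's custom_prompt.

def render_mistral_instruct(messages):
    # Mirrors the role_dict above: every role is wrapped in [INST] ... [/INST].
    wrap = {"pre_message": "[INST]", "post_message": "[/INST]"}
    prompt = ""  # initial_prompt_value is ""
    for message in messages:
        prompt += wrap["pre_message"] + message["content"] + wrap["post_message"]
    return prompt  # final_prompt_value is also ""

render_mistral_instruct([
    {"role": "user", "content": "Hello! How are you today?"},
    {"role": "assistant", "content": "I'm doing well, thank you!"},
])
# -> "[INST]Hello! How are you today?[/INST][INST]I'm doing well, thank you![/INST]"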
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 760ffa2d9..0b55e868a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -199,9 +199,9 @@ def test_get_hf_task_for_model():
# def hf_test_completion_tgi():
# try:
# response = litellm.completion(
-# model="huggingface/glaiveai/glaive-coder-7b",
+# model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
# messages=[{ "content": "Hello, how are you?","role": "user"}],
-# api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
+# api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud",
# )
# # Add any assertions here to check the response
# print(response)
@@ -646,16 +646,7 @@ def test_completion_azure_deployment_id():
# pytest.fail(f"Error occurred: {e}")
# test_completion_anthropic_litellm_proxy()
-# def test_hf_conversational_task():
-# try:
-# messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
-# # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
-# response = completion(model="huggingface/facebook/blenderbot-400M-distill", messages=messages, task="conversational")
-# print(f"response: {response}")
-# except Exception as e:
-# pytest.fail(f"Error occurred: {e}")
-# test_hf_conversational_task()
# Replicate API endpoints are unstable -> they throw random CUDA errors -> our tests can fail even when the tests themselves are correct.
# def test_completion_replicate_llama_2():
@@ -792,7 +783,7 @@ def test_completion_bedrock_claude():
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
-test_completion_bedrock_claude()
+# test_completion_bedrock_claude()
def test_completion_bedrock_claude_stream():
print("calling claude")
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index b628e959d..f4a3db36e 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -314,33 +314,58 @@ def test_completion_cohere_stream_bad_key():
# test_completion_nlp_cloud_bad_key()
-# def test_completion_hf_stream():
-# try:
-# messages = [
-# {
-# "content": "Hello! How are you today?",
-# "role": "user"
-# },
-# ]
-# response = completion(
-# model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
-# )
-# complete_response = ""
-# # Add any assertions here to check the response
-# for idx, chunk in enumerate(response):
-# chunk, finished = streaming_format_tests(idx, chunk)
-# if finished:
-# break
-# complete_response += chunk
-# if complete_response.strip() == "":
-# raise Exception("Empty response received")
-# print(f"completion_response: {complete_response}")
-# except InvalidRequestError as e:
-# pass
-# except Exception as e:
-# pytest.fail(f"Error occurred: {e}")
+def test_completion_hf_stream():
+ try:
+ litellm.set_verbose = True
+ # messages = [
+ # {
+ # "content": "Hello! How are you today?",
+ # "role": "user"
+ # },
+ # ]
+ # response = completion(
+ # model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+ # )
+ # complete_response = ""
+ # # Add any assertions here to check the response
+ # for idx, chunk in enumerate(response):
+ # chunk, finished = streaming_format_tests(idx, chunk)
+ # if finished:
+ # break
+ # complete_response += chunk
+ # if complete_response.strip() == "":
+ # raise Exception("Empty response received")
+ # completion_response_1 = complete_response
+ messages = [
+ {
+ "content": "Hello! How are you today?",
+ "role": "user"
+ },
+ {
+ "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?",
+ "role": "assistant"
+ },
+ ]
+ response = completion(
+ model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+ )
+ complete_response = ""
+ # Add any assertions here to check the response
+ for idx, chunk in enumerate(response):
+ chunk, finished = streaming_format_tests(idx, chunk)
+ if finished:
+ break
+ complete_response += chunk
+ if complete_response.strip() == "":
+ raise Exception("Empty response received")
+ # print(f"completion_response_1: {completion_response_1}")
+ print(f"completion_response: {complete_response}")
+ except InvalidRequestError as e:
+ pass
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
-# # test_completion_hf_stream()
+test_completion_hf_stream()
# def test_completion_hf_stream_bad_key():
# try:
@@ -680,7 +705,7 @@ def test_completion_sagemaker_stream():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()
# test on openai completion call
def test_openai_text_completion_call():
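Note: the re-enabled streaming test boils down to accumulating per-chunk deltas. A minimal sketch, assuming each streamed chunk follows the OpenAI-style shape where the incremental text lives in choices[0]["delta"]["content"]; the api_base value is a placeholder for your own Hugging Face Inference Endpoint.

import litellm

response = litellm.completion(
    model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Hello! How are you today?"}],
    api_base="https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder
    stream=True,
    max_tokens=1000,
)

complete_response = ""
for chunk in response:
    # Each chunk carries an incremental delta; concatenate them into the full reply.
    delta = chunk["choices"][0]["delta"]
    complete_response += delta.get("content", "") or ""
print(complete_response)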
diff --git a/litellm/utils.py b/litellm/utils.py
index 256fe061d..f9a286bf6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2916,12 +2916,14 @@ class CustomStreamWrapper:
print_verbose(f"data json: {data_json}")
if "token" in data_json and "text" in data_json["token"]:
text = data_json["token"]["text"]
- if "meta-llama/Llama-2" in self.model: #clean eos tokens like from the returned output text
- if any(token in text for token in llama_2_special_tokens):
- text = text.replace("", "").replace("", "")
if data_json.get("details", False) and data_json["details"].get("finish_reason", False):
is_finished = True
finish_reason = data_json["details"]["finish_reason"]
+ elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete
+ text = "" # don't return the final bos token
+ is_finished = True
+ finish_reason = "stop"
+
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
elif "error" in chunk:
raise ValueError(chunk)
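Note: a minimal sketch of the TGI chunk handling this hunk implements, assuming each server-sent event decodes to a dict with an optional token.text, an optional details.finish_reason, and a generated_text field on the final event. parse_tgi_chunk is an illustrative helper, not litellm's API.

def parse_tgi_chunk(data_json: dict) -> dict:
    # Mirrors the CustomStreamWrapper branch above.
    text, is_finished, finish_reason = "", False, ""
    if "token" in data_json and "text" in data_json["token"]:
        text = data_json["token"]["text"]
    if data_json.get("details", False) and data_json["details"].get("finish_reason", False):
        is_finished = True
        finish_reason = data_json["details"]["finish_reason"]
    elif data_json.get("generated_text", False):
        # The final event repeats the full generated_text; emit nothing extra.
        text = ""
        is_finished = True
        finish_reason = "stop"
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}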
diff --git a/pyproject.toml b/pyproject.toml
index 62daf0d5a..7f99f8375 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "0.1.799"
+version = "0.1.800"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"