forked from phoenix/litellm-mirror
add mistral prompt templating
commit e8ec3e8795 (parent 82c642f78d)
6 changed files with 85 additions and 43 deletions
Binary file not shown.
@@ -22,6 +22,28 @@ def llama_2_chat_pt(messages):
     )
     return prompt
 
+def mistral_instruct_pt(messages):
+    prompt = custom_prompt(
+        initial_prompt_value="<s>",
+        role_dict={
+            "system": {
+                "pre_message": "[INST]",
+                "post_message": "[/INST]"
+            },
+            "user": {
+                "pre_message": "[INST]",
+                "post_message": "[/INST]"
+            },
+            "assistant": {
+                "pre_message": "[INST]",
+                "post_message": "[/INST]"
+            }
+        },
+        final_prompt_value="</s>",
+        messages=messages
+    )
+    return prompt
+
 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
 def falcon_instruct_pt(messages):
     prompt = ""

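For orientation, the template added above wraps each message in [INST] ... [/INST] and brackets the whole conversation with <s> ... </s>. A minimal standalone sketch of the resulting string follows; the exact concatenation (spacing, per-role handling) is an assumption about litellm's custom_prompt helper, not something this diff specifies.

# Hypothetical re-implementation of the assembly, for illustration only;
# custom_prompt in litellm may join the pieces slightly differently.
def mistral_instruct_sketch(messages):
    role_wrap = {"pre_message": "[INST]", "post_message": "[/INST]"}  # same wrap for system/user/assistant above
    prompt = "<s>"  # initial_prompt_value
    for message in messages:
        prompt += role_wrap["pre_message"] + message["content"] + role_wrap["post_message"]
    prompt += "</s>"  # final_prompt_value
    return prompt

print(mistral_instruct_sketch([{"role": "user", "content": "Hey, how's it going?"}]))
# -> <s>[INST]Hey, how's it going?[/INST]</s>
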
@@ -116,4 +138,6 @@ def prompt_factory(model: str, messages: list):
         return phind_codellama_pt(messages=messages)
     elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
         return llama_2_chat_pt(messages=messages)
+    elif "mistralai/mistral" in model and "instruct" in model:
+        return mistral_instruct_pt(messages=messages)
     return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)

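The new branch only fires when the model name contains both "mistralai/mistral" and "instruct"; base Mistral models still fall through to default_pt. A toy illustration of that dispatch order is below; it assumes the factory lowercases the model name before these substring checks, which the lowercase literals suggest but this hunk does not show.

# Toy stand-in for the routing added to prompt_factory; the template function
# names come from the diff, the lowercasing step is an assumption.
def pick_template(model: str) -> str:
    model = model.lower()
    if "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
        return "llama_2_chat_pt"
    elif "mistralai/mistral" in model and "instruct" in model:
        return "mistral_instruct_pt"
    return "default_pt"

print(pick_template("mistralai/Mistral-7B-Instruct-v0.1"))  # mistral_instruct_pt
print(pick_template("mistralai/Mistral-7B-v0.1"))           # default_pt (no "instruct" in the name)
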
@@ -199,9 +199,9 @@ def test_get_hf_task_for_model():
 # def hf_test_completion_tgi():
 #     try:
 #         response = litellm.completion(
-#             model="huggingface/glaiveai/glaive-coder-7b",
+#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
 #             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
+#             api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud",
 #         )
 #         # Add any assertions here to check the response
 #         print(response)

@@ -646,16 +646,7 @@ def test_completion_azure_deployment_id():
 #         pytest.fail(f"Error occurred: {e}")
 
 # test_completion_anthropic_litellm_proxy()
-# def test_hf_conversational_task():
-#     try:
-#         messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
-#         # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
-#         response = completion(model="huggingface/facebook/blenderbot-400M-distill", messages=messages, task="conversational")
-#         print(f"response: {response}")
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
 
-# test_hf_conversational_task()
 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
 
 # def test_completion_replicate_llama_2():

@@ -792,7 +783,7 @@ def test_completion_bedrock_claude():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_bedrock_claude()
+# test_completion_bedrock_claude()
 
 def test_completion_bedrock_claude_stream():
     print("calling claude")

@@ -314,8 +314,9 @@ def test_completion_cohere_stream_bad_key():
 
 # test_completion_nlp_cloud_bad_key()
 
-# def test_completion_hf_stream():
-#     try:
+def test_completion_hf_stream():
+    try:
+        litellm.set_verbose = True
 #         messages = [
 #             {
 #                 "content": "Hello! How are you today?",

@@ -323,7 +324,7 @@ def test_completion_cohere_stream_bad_key():
 #             },
 #         ]
 #         response = completion(
-#             model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
 #         )
 #         complete_response = ""
 #         # Add any assertions here to check the response

@@ -334,13 +335,37 @@ def test_completion_cohere_stream_bad_key():
 #             complete_response += chunk
 #         if complete_response.strip() == "":
 #             raise Exception("Empty response received")
-#         print(f"completion_response: {complete_response}")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+#         completion_response_1 = complete_response
+        messages = [
+            {
+                "content": "Hello! How are you today?",
+                "role": "user"
+            },
+            {
+                "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?</s>",
+                "role": "assistant"
+            },
+        ]
+        response = completion(
+            model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        # print(f"completion_response_1: {completion_response_1}")
+        print(f"completion_response: {complete_response}")
+    except InvalidRequestError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
-# # test_completion_hf_stream()
+test_completion_hf_stream()
 
 # def test_completion_hf_stream_bad_key():
 #     try:

@@ -680,7 +705,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()
 
 # test on openai completion call
 def test_openai_text_completion_call():

@@ -2916,12 +2916,14 @@ class CustomStreamWrapper:
             print_verbose(f"data json: {data_json}")
             if "token" in data_json and "text" in data_json["token"]:
                 text = data_json["token"]["text"]
-                if "meta-llama/Llama-2" in self.model: #clean eos tokens like </s> from the returned output text
-                    if any(token in text for token in llama_2_special_tokens):
-                        text = text.replace("<s>", "").replace("</s>", "")
             if data_json.get("details", False) and data_json["details"].get("finish_reason", False):
                 is_finished = True
                 finish_reason = data_json["details"]["finish_reason"]
+            elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete
+                text = "" # don't return the final bos token
+                is_finished = True
+                finish_reason = "stop"
+
             return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
         elif "error" in chunk:
             raise ValueError(chunk)

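The streaming change above drops the Llama-2-specific stripping of <s>/</s> and instead treats a chunk carrying a top-level generated_text field as the end of the stream, returning empty text with finish_reason "stop". A self-contained sketch of that logic follows; the helper and its sample payloads are illustrative, modeled on text-generation-inference's SSE format, and are not litellm's actual handler.

import json

# Illustrative chunk handler mirroring the branch added above; payload shapes
# are assumptions based on TGI streaming responses, not taken from this diff.
def handle_tgi_chunk(raw_line: str):
    text, is_finished, finish_reason = "", False, ""
    if raw_line.startswith("data:"):
        data_json = json.loads(raw_line[len("data:"):])
        if "token" in data_json and "text" in data_json["token"]:
            text = data_json["token"]["text"]
        if data_json.get("details", False) and data_json["details"].get("finish_reason", False):
            is_finished = True
            finish_reason = data_json["details"]["finish_reason"]
        elif data_json.get("generated_text", False):
            text = ""  # full generated text present -> don't re-emit the final token
            is_finished = True
            finish_reason = "stop"
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

# A mid-stream chunk yields its token text; the final chunk (which carries
# "generated_text") yields empty text with finish_reason "stop".
print(handle_tgi_chunk('data:{"token": {"text": " Hello"}}'))
print(handle_tgi_chunk('data:{"token": {"text": "</s>"}, "generated_text": "Hello there!", "details": null}'))
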
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.799"
+version = "0.1.800"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"