From e8ec3e87959af79a09bd838d5fbc01aa566a85b5 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 29 Sep 2023 21:41:19 -0700
Subject: [PATCH] add mistral prompt templating

---
 litellm/__pycache__/utils.cpython-311.pyc | Bin 137051 -> 136629 bytes
 litellm/llms/prompt_templates/factory.py  |  24 +++++++
 litellm/tests/test_completion.py          |  15 +---
 litellm/tests/test_streaming.py           |  79 ++++++++++++++--------
 litellm/utils.py                          |   8 ++-
 pyproject.toml                            |   2 +-
 6 files changed, 85 insertions(+), 43 deletions(-)

diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index da70c6bd25172ba75ffd3c0b33935789a20b2572..3a881f73d32ff80c5ba9d115ac21f067f103ddd0 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index c8c423db2..d47a7486d 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -22,6 +22,28 @@ def llama_2_chat_pt(messages):
     )
     return prompt
 
+def mistral_instruct_pt(messages):
+    prompt = custom_prompt(
+        initial_prompt_value="",
+        role_dict={
+            "system": {
"pre_message": "[INST]", + "post_message": "[/INST]" + }, + "user": { + "pre_message": "[INST]", + "post_message": "[/INST]" + }, + "assistant": { + "pre_message": "[INST]", + "post_message": "[/INST]" + } + }, + final_prompt_value="", + messages=messages + ) + return prompt + # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 def falcon_instruct_pt(messages): prompt = "" @@ -116,4 +138,6 @@ def prompt_factory(model: str, messages: list): return phind_codellama_pt(messages=messages) elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model): return llama_2_chat_pt(messages=messages) + elif "mistralai/mistral" in model and "instruct" in model: + return mistral_instruct_pt(messages=messages) return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) \ No newline at end of file diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 760ffa2d9..0b55e868a 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -199,9 +199,9 @@ def test_get_hf_task_for_model(): # def hf_test_completion_tgi(): # try: # response = litellm.completion( -# model="huggingface/glaiveai/glaive-coder-7b", +# model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", # messages=[{ "content": "Hello, how are you?","role": "user"}], -# api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud", +# api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", # ) # # Add any assertions here to check the response # print(response) @@ -646,16 +646,7 @@ def test_completion_azure_deployment_id(): # pytest.fail(f"Error occurred: {e}") # test_completion_anthropic_litellm_proxy() -# def test_hf_conversational_task(): -# try: -# messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}] -# # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints -# response = completion(model="huggingface/facebook/blenderbot-400M-distill", messages=messages, task="conversational") -# print(f"response: {response}") -# except Exception as e: -# pytest.fail(f"Error occurred: {e}") -# test_hf_conversational_task() # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect. # def test_completion_replicate_llama_2(): @@ -792,7 +783,7 @@ def test_completion_bedrock_claude(): print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") -test_completion_bedrock_claude() +# test_completion_bedrock_claude() def test_completion_bedrock_claude_stream(): print("calling claude") diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index b628e959d..f4a3db36e 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -314,33 +314,58 @@ def test_completion_cohere_stream_bad_key(): # test_completion_nlp_cloud_bad_key() -# def test_completion_hf_stream(): -# try: -# messages = [ -# { -# "content": "Hello! 
How are you today?",
-#                 "role": "user"
-#             },
-#         ]
-#         response = completion(
-#             model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
-#         )
-#         complete_response = ""
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#         print(f"completion_response: {complete_response}")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_hf_stream():
+    try:
+        litellm.set_verbose = True
+        # messages = [
+        #     {
+        #         "content": "Hello! How are you today?",
+        #         "role": "user"
+        #     },
+        # ]
+        # response = completion(
+        #     model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+        # )
+        # complete_response = ""
+        # # Add any assertions here to check the response
+        # for idx, chunk in enumerate(response):
+        #     chunk, finished = streaming_format_tests(idx, chunk)
+        #     if finished:
+        #         break
+        #     complete_response += chunk
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
+        # completion_response_1 = complete_response
+        messages = [
+            {
+                "content": "Hello! How are you today?",
+                "role": "user"
+            },
+            {
+                "content": "I'm doing well, thank you for asking! I'm excited to be here and help you with any questions or concerns you may have. What can I assist you with today?",
+                "role": "assistant"
+            },
+        ]
+        response = completion(
+            model="huggingface/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        # print(f"completion_response_1: {completion_response_1}")
+        print(f"completion_response: {complete_response}")
+    except InvalidRequestError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
-# # test_completion_hf_stream()
+test_completion_hf_stream()
 
 # def test_completion_hf_stream_bad_key():
 #     try:
@@ -680,7 +705,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()
 
 # test on openai completion call
 def test_openai_text_completion_call():
diff --git a/litellm/utils.py b/litellm/utils.py
index 256fe061d..f9a286bf6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2916,12 +2916,14 @@ class CustomStreamWrapper:
             print_verbose(f"data json: {data_json}")
             if "token" in data_json and "text" in data_json["token"]:
                 text = data_json["token"]["text"]
-                if "meta-llama/Llama-2" in self.model: #clean eos tokens like </s> from the returned output text
-                    if any(token in text for token in llama_2_special_tokens):
-                        text = text.replace("<s>", "").replace("</s>", "")
             if data_json.get("details", False) and data_json["details"].get("finish_reason", False):
                 is_finished = True
                 finish_reason = 
data_json["details"]["finish_reason"] + elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete + text = "" # don't return the final bos token + is_finished = True + finish_reason = "stop" + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif "error" in chunk: raise ValueError(chunk) diff --git a/pyproject.toml b/pyproject.toml index 62daf0d5a..7f99f8375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "0.1.799" +version = "0.1.800" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"