diff --git a/dist/litellm-0.1.696-py3-none-any.whl b/dist/litellm-0.1.696-py3-none-any.whl
new file mode 100644
index 000000000..d20bbf5af
Binary files /dev/null and b/dist/litellm-0.1.696-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.696.tar.gz b/dist/litellm-0.1.696.tar.gz
new file mode 100644
index 000000000..530794fbf
Binary files /dev/null and b/dist/litellm-0.1.696.tar.gz differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index db5819b91..b5c44d665 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index a7e041bf5..f824c8ca3 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2,6 +2,7 @@ def default_pt(messages):
     return " ".join(message["content"] for message in messages)
 
 # Llama2 prompt template
+llama_2_special_tokens = ["<s>", "</s>"]
 def llama_2_chat_pt(messages):
     prompt = custom_prompt(
         role_dict={
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 8465658c4..535a0564f 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -111,8 +111,8 @@ def test_completion_with_litellm_call_id():
 # try:
 #     user_message = "write some code to find the sum of two numbers"
 #     messages = [{ "content": user_message,"role": "user"}]
-#     api_base = "https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud"
-#     response = completion(model="togethercomputer/LLaMA-2-7B-32K", messages=messages, custom_llm_provider="huggingface", api_base=api_base, logger_fn=logger_fn)
+#     api_base = "https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud"
+#     response = completion(model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base=api_base, logger_fn=logger_fn)
 #     # Add any assertions here to check the response
 #     print(response)
 # except Exception as e:
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 10d721478..10f772c25 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -214,6 +214,32 @@ def test_completion_cohere_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# def test_completion_hf_stream():
+#     try:
+#         messages = [
+#             {
+#                 "content": "Hello! How are you today?",
+#                 "role": "user"
+#             },
+#         ]
+#         response = completion(
+#             model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+#         )
+#         complete_response = ""
+#         # Add any assertions here to check the response
+#         for idx, chunk in enumerate(response):
+#             chunk, finished = streaming_format_tests(idx, chunk)
+#             if finished:
+#                 break
+#             complete_response += chunk
+#         if complete_response.strip() == "":
+#             raise Exception("Empty response received")
+#         print(f"completion_response: {complete_response}")
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_hf_stream()
+
 def test_completion_claude_stream():
     try:
         messages = [
diff --git a/litellm/utils.py b/litellm/utils.py
index bd91583e1..648bb63e7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -36,7 +36,7 @@ from .exceptions import (
 )
 from typing import cast, List, Dict, Union, Optional
 from .caching import Cache
-
+from .llms.prompt_templates.factory import llama_2_special_tokens
 ####### ENVIRONMENT VARIABLES ####################
 dotenv.load_dotenv() # Loading env variables using dotenv
 
@@ -2518,7 +2518,11 @@ class CustomStreamWrapper:
         if chunk.startswith("data:"):
             data_json = json.loads(chunk[5:])
             if "token" in data_json and "text" in data_json["token"]:
-                return data_json["token"]["text"]
+                text = data_json["token"]["text"]
+                if "meta-llama/Llama-2" in self.model: # clean eos tokens like </s> from the returned output text
+                    if any(token in text for token in llama_2_special_tokens):
+                        text = text.replace("<s>", "").replace("</s>", "")
+                return text
             else:
                 return ""
         return ""
diff --git a/pyproject.toml b/pyproject.toml
index 5fbe4b0e7..6df802fb3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.696"
+version = "0.1.697"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
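
For reference, here is a minimal standalone sketch of the streamed-token cleanup that the litellm/utils.py hunk introduces, assuming the Llama 2 special tokens "<s>" and "</s>" restored above; the helper name strip_llama_2_special_tokens is hypothetical and used only to isolate the logic outside CustomStreamWrapper:

# Minimal sketch of the cleanup applied in CustomStreamWrapper, assuming
# llama_2_special_tokens = ["<s>", "</s>"] as defined in factory.py.
llama_2_special_tokens = ["<s>", "</s>"]

def strip_llama_2_special_tokens(model: str, text: str) -> str:
    # Hypothetical helper: only Llama 2 models emit these BOS/EOS markers
    # in streamed Hugging Face token text, so other models pass through.
    if "meta-llama/Llama-2" in model:
        if any(token in text for token in llama_2_special_tokens):
            text = text.replace("<s>", "").replace("</s>", "")
    return text

# Example: the EOS marker is stripped from a streamed chunk.
assert strip_llama_2_special_tokens(
    "huggingface/meta-llama/Llama-2-7b-chat-hf", "Hello!</s>"
) == "Hello!"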