handle llama 2 eos tokens in streaming

Krrish Dholakia 2023-09-18 13:44:19 -07:00
parent d81c75a0b3
commit 633e36de42
8 changed files with 36 additions and 5 deletions

BIN  dist/litellm-0.1.696-py3-none-any.whl (vendored, new file; binary file not shown)
BIN  dist/litellm-0.1.696.tar.gz (vendored, new file; binary file not shown)

@@ -2,6 +2,7 @@ def default_pt(messages):
     return " ".join(message["content"] for message in messages)
 
 # Llama2 prompt template
+llama_2_special_tokens = ["<s>", "</s>"]
 def llama_2_chat_pt(messages):
     prompt = custom_prompt(
         role_dict={
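
For context, Llama-2 chat prompts wrap each turn in these BOS/EOS markers, which is why they can leak into generated text. A minimal, hypothetical sketch of the format (illustrative only, not the llama_2_chat_pt implementation above):

# Hypothetical sketch of the Llama-2 chat format; helper name and
# structure are illustrative, not litellm's actual implementation.
BOS, EOS = "<s>", "</s>"

def sketch_llama_2_prompt(messages):
    prompt = ""
    for message in messages:
        if message["role"] == "user":
            # each user turn opens with <s> and an [INST] block
            prompt += f"{BOS}[INST] {message['content']} [/INST]"
        elif message["role"] == "assistant":
            # each assistant turn is closed with </s>, the eos token
            prompt += f" {message['content']} {EOS}"
    return prompt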

@@ -111,8 +111,8 @@ def test_completion_with_litellm_call_id():
 # try:
 #     user_message = "write some code to find the sum of two numbers"
 #     messages = [{ "content": user_message,"role": "user"}]
-#     api_base = "https://ecd4sb5n09bo4ei2.us-east-1.aws.endpoints.huggingface.cloud"
-#     response = completion(model="togethercomputer/LLaMA-2-7B-32K", messages=messages, custom_llm_provider="huggingface", api_base=api_base, logger_fn=logger_fn)
+#     api_base = "https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud"
+#     response = completion(model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base=api_base, logger_fn=logger_fn)
 #     # Add any assertions here to check the response
 #     print(response)
 # except Exception as e:
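
Note the dropped custom_llm_provider argument: with a "huggingface/" prefix on the model string, litellm infers the provider from the prefix. A rough usage sketch (the endpoint URL is a placeholder, and the call needs live credentials):

# Sketch: the "huggingface/" prefix selects the provider, so no
# custom_llm_provider kwarg is needed. api_base is a placeholder.
from litellm import completion

response = completion(
    model="huggingface/meta-llama/Llama-2-7b-chat-hf",  # provider/model-id
    messages=[{"role": "user", "content": "hi"}],
    api_base="https://<your-endpoint>.endpoints.huggingface.cloud",
)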

@@ -214,6 +214,32 @@ def test_completion_cohere_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# def test_completion_hf_stream():
+#     try:
+#         messages = [
+#             {
+#                 "content": "Hello! How are you today?",
+#                 "role": "user"
+#             },
+#         ]
+#         response = completion(
+#             model="huggingface/meta-llama/Llama-2-7b-chat-hf", messages=messages, api_base="https://a8l9e3ucxinyl3oj.us-east-1.aws.endpoints.huggingface.cloud", stream=True, max_tokens=1000
+#         )
+#         complete_response = ""
+#         # Add any assertions here to check the response
+#         for idx, chunk in enumerate(response):
+#             chunk, finished = streaming_format_tests(idx, chunk)
+#             if finished:
+#                 break
+#             complete_response += chunk
+#         if complete_response.strip() == "":
+#             raise Exception("Empty response received")
+#         print(f"completion_response: {complete_response}")
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_hf_stream()
+
 def test_completion_claude_stream():
     try:
         messages = [
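
The test above relies on the repo-local streaming_format_tests helper. Outside the test suite, consuming the stream looks roughly like this (assuming litellm's OpenAI-style chunk shape; model and endpoint are placeholders and need a live deployment):

# Rough consumer sketch, assuming OpenAI-style streaming chunks with
# choices[0]["delta"]["content"]; requires a reachable endpoint to run.
from litellm import completion

response = completion(
    model="huggingface/meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Hello! How are you today?"}],
    stream=True,
)
complete_response = ""
for chunk in response:
    delta = chunk["choices"][0]["delta"]
    complete_response += delta.get("content", "") or ""
# before this commit, complete_response could end with a literal "</s>"
print(complete_response)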

@@ -36,7 +36,7 @@ from .exceptions import (
 )
 from typing import cast, List, Dict, Union, Optional
 from .caching import Cache
+from .llms.prompt_templates.factory import llama_2_special_tokens
 
 ####### ENVIRONMENT VARIABLES ####################
 dotenv.load_dotenv() # Loading env variables using dotenv

@@ -2518,7 +2518,11 @@ class CustomStreamWrapper:
             if chunk.startswith("data:"):
                 data_json = json.loads(chunk[5:])
                 if "token" in data_json and "text" in data_json["token"]:
-                    return data_json["token"]["text"]
+                    text = data_json["token"]["text"]
+                    if "meta-llama/Llama-2" in self.model: # clean eos tokens like </s> from the returned output text
+                        if any(token in text for token in llama_2_special_tokens):
+                            text = text.replace("<s>", "").replace("</s>", "")
+                    return text
                 else:
                     return ""
             return ""

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.696"
+version = "0.1.697"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"