huggingface conversational task support

Krrish Dholakia 2023-09-13 13:44:46 -07:00
parent 1913d36e05
commit 5b6b9a9fab
7 changed files with 75 additions and 40 deletions

litellm/llms/huggingface_restapi.py

@@ -40,7 +40,9 @@ def completion(
    logger_fn=None,
):
    headers = validate_environment(api_key)
    task = optional_params.pop("task")
    completion_url = ""
    input_text = None
    if "https" in model:
        completion_url = model
    elif api_base:
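Note: optional_params.pop("task") assumes the key is always present. That holds on litellm's own call path, because get_optional_params in utils.py (see below) now sets "task" unconditionally for the huggingface provider, and main.py defaults it to "text-generation-inference". A caller invoking this module directly would hit a KeyError; a defensive variant, not part of this commit, could mirror that default:

    # hypothetical hardening, mirroring main.py's default value
    task = optional_params.pop("task", "text-generation-inference")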
@@ -49,6 +51,31 @@
        completion_url = os.getenv("HF_API_BASE", "")
    else:
        completion_url = f"https://api-inference.huggingface.co/models/{model}"
    ### MAP INPUT PARAMS
    if task == "conversational":
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop("details")
        past_user_inputs = []
        generated_responses = []
        text = ""
        for message in messages:
            if message["role"] == "user":
                if text != "":
                    past_user_inputs.append(text)
                text = message["content"]
            elif message["role"] == "assistant" or message["role"] == "system":
                generated_responses.append(message["content"])
        data = {
            "inputs": {
                "text": text,
                "past_user_inputs": past_user_inputs,
                "generated_responses": generated_responses
            },
            "parameters": inference_params
        }
        input_text = "".join(message["content"] for message in messages)
    elif task == "text-generation-inference":
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
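For reference, a minimal standalone sketch of the mapping this hunk performs: OpenAI-style chat messages are split into the text / past_user_inputs / generated_responses fields that Hugging Face's conversational pipeline expects. The messages below are illustrative:

    messages = [
        {"role": "user", "content": "Hi, who are you?"},
        {"role": "assistant", "content": "I'm a chatbot."},
        {"role": "user", "content": "What can you do?"},
    ]

    past_user_inputs, generated_responses, text = [], [], ""
    for message in messages:
        if message["role"] == "user":
            if text != "":
                past_user_inputs.append(text)  # earlier user turns become history
            text = message["content"]          # latest user turn is the live input
        elif message["role"] in ("assistant", "system"):
            generated_responses.append(message["content"])

    # resulting "inputs" payload:
    # {"text": "What can you do?",
    #  "past_user_inputs": ["Hi, who are you?"],
    #  "generated_responses": ["I'm a chatbot."]}

input_text, used only for logging and token counting, is the plain concatenation of all message contents.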
@@ -60,7 +87,6 @@ def completion(
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)
        ### MAP INPUT PARAMS
        if "https://api-inference.huggingface.co/models" in completion_url:
            inference_params = copy.deepcopy(optional_params)
            inference_params.pop("details")
@@ -75,11 +101,12 @@
                "parameters": optional_params,
                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
            }
        input_text = prompt
    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        input=input_text,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        additional_args={"complete_input_dict": data, "task": task},
    )
    ## COMPLETION CALL
    if "stream" in optional_params and optional_params["stream"] == True:
@@ -98,10 +125,10 @@
        )
    ## LOGGING
    logging_obj.post_call(
        input=prompt,
        input=input_text,
        api_key=api_key,
        original_response=response.text,
        additional_args={"complete_input_dict": data},
        additional_args={"complete_input_dict": data, "task": task},
    )
    ## RESPONSE OBJECT
    try:
@@ -119,10 +146,14 @@
                status_code=response.status_code,
            )
        else:
            if task == "conversational":
                model_response["choices"][0]["message"][
                    "content"
                ] = completion_response["generated_text"]
            elif task == "text-generation-inference":
                model_response["choices"][0]["message"][
                    "content"
                ] = completion_response[0]["generated_text"]
            ## GETTING LOGPROBS
            if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
                sum_logprob = 0
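The two branches exist because the Inference API returns differently shaped JSON per task. Roughly, with fields abbreviated (payload shapes from Hugging Face's API docs, not shown in this diff):

    # task="text-generation-inference" -> a list of generations
    # [{"generated_text": "...", "details": {"tokens": [...]}}]

    # task="conversational" -> a single object
    # {"generated_text": "...",
    #  "conversation": {"past_user_inputs": [...], "generated_responses": [...]}}

One caveat: the logprobs check still indexes completion_response[0], which only matches the list shape; for a conversational dict response that lookup would raise a KeyError.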
@@ -131,7 +162,7 @@
                model_response["choices"][0]["message"]["logprobs"] = sum_logprob
        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(prompt)
            encoding.encode(input_text)
        ) ##[TODO] use the llama2 tokenizer here
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"]["content"])
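Token usage is approximated with litellm's shared encoding object (a tiktoken encoding), applied to input_text rather than prompt since the conversational task never builds a prompt string; the TODO notes that the model's own tokenizer (e.g. llama2's) would be more faithful. A standalone sketch of the same calculation, assuming tiktoken's cl100k_base encoding:

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")  # approximation, not the HF model's tokenizer

    prompt_tokens = len(encoding.encode("Hi, who are you?"))
    completion_tokens = len(encoding.encode("I am a chatbot."))
    total_tokens = prompt_tokens + completion_tokens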

litellm/main.py

@@ -109,8 +109,8 @@ def completion(
    use_client=False,
    id=None, # this is an optional param to tag individual completion calls
    # model specific optional params
    # used by text-bison only
    top_k=40,
    top_k=40,# used by text-bison only
    task: Optional[str]="text-generation-inference", # used by huggingface inference endpoints
    request_timeout=0, # unused var for old version of OpenAI API
    fallbacks=[],
    caching = False,
@@ -154,6 +154,7 @@
        model=model,
        custom_llm_provider=custom_llm_provider,
        top_k=top_k,
        task=task
    )
    # For logging - save the values of the litellm-specific params passed in
    litellm_params = get_litellm_params(
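With task threaded through completion() and get_optional_params(), a conversational call now looks like this (a sketch; the model and message are illustrative, mirroring the updated test below):

    import litellm

    response = litellm.completion(
        model="facebook/blenderbot-400M-distill",
        messages=[{"role": "user", "content": "Hi, who are you?"}],
        custom_llm_provider="huggingface",
        task="conversational",
    )
    print(response["choices"][0]["message"]["content"])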

litellm/tests/test_completion.py

@@ -119,7 +119,8 @@ def test_completion_claude_stream():
# try:
#     user_message = "write some code to find the sum of two numbers"
#     messages = [{ "content": user_message,"role": "user"}]
#     response = completion(model="stabilityai/stablecode-instruct-alpha-3b", messages=messages, custom_llm_provider="huggingface", logger_fn=logger_fn)
#     api_base = "https://wyh9bqfgj2r1klv5.us-east-1.aws.endpoints.huggingface.cloud"
#     response = completion(model="facebook/blenderbot-400M-distill", messages=messages, custom_llm_provider="huggingface", task="conversational", api_base=api_base, logger_fn=logger_fn)
#     # Add any assertions here to check the response
#     print(response)
# except Exception as e:

litellm/utils.py

@@ -788,6 +788,7 @@ def get_optional_params( # use the openai defaults
    model=None,
    custom_llm_provider="",
    top_k=40,
    task=None
):
    optional_params = {}
    if model in litellm.anthropic_models:
@@ -882,6 +883,7 @@
        if presence_penalty != 0:
            optional_params["repetition_penalty"] = presence_penalty
        optional_params["details"] = True
        optional_params["task"] = task
    elif custom_llm_provider == "sagemaker":
        if "llama-2" in model:
            # llama-2 models on sagemaker support the following args
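So for the huggingface provider, get_optional_params now always returns "task" alongside the inference parameters, and huggingface_restapi.py pops it back off before the request payload is built (first hunk above). Illustratively, the dict handed to the provider might look like:

    # illustrative output for custom_llm_provider="huggingface"
    optional_params = {
        "details": True,            # always set, enables logprobs in the response
        "task": "conversational",   # consumed by huggingface_restapi.py, never sent to HF
        # ...plus any mapped sampling params, e.g. "temperature"
    }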

pyproject.toml

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.618"
version = "0.1.619"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"