forked from phoenix/litellm-mirror

Commit 5b6b9a9fab (parent 1913d36e05): huggingface conversational task support

7 changed files with 75 additions and 40 deletions
2 binary files changed (contents not shown)
HuggingFace provider handler (def completion):

@@ -40,7 +40,9 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
+    task = optional_params.pop("task")
     completion_url = ""
+    input_text = None
     if "https" in model:
         completion_url = model
     elif api_base:
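
For orientation, the new "task" key arrives inside optional_params alongside the real generation parameters and is popped off before those parameters are forwarded to Hugging Face. A minimal sketch of that hand-off, with made-up parameter values (not taken from the commit):

# Illustrative values only.
optional_params = {"task": "conversational", "details": True, "temperature": 0.7}

# Pop the litellm-specific routing hint so only real HF parameters remain.
task = optional_params.pop("task")

assert task == "conversational"
assert "task" not in optional_params   # temperature and details are still present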

@@ -49,37 +51,62 @@ def completion(
         completion_url = os.getenv("HF_API_BASE", "")
     else:
         completion_url = f"https://api-inference.huggingface.co/models/{model}"
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
-    else:
-        prompt = prompt_factory(model=model, messages=messages)
     ### MAP INPUT PARAMS
-    if "https://api-inference.huggingface.co/models" in completion_url:
+    if task == "conversational":
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
+        past_user_inputs = []
+        generated_responses = []
+        text = ""
+        for message in messages:
+            if message["role"] == "user":
+                if text != "":
+                    past_user_inputs.append(text)
+                text = message["content"]
+            elif message["role"] == "assistant" or message["role"] == "system":
+                generated_responses.append(message["content"])
         data = {
-            "inputs": prompt,
-            "parameters": inference_params,
-            "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
-        }
-    else:
-        data = {
-            "inputs": prompt,
-            "parameters": optional_params,
-            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            "inputs": {
+                "text": text,
+                "past_user_inputs": past_user_inputs,
+                "generated_responses": generated_responses
+            },
+            "parameters": inference_params
         }
+        input_text = "".join(message["content"] for message in messages)
+    elif task == "text-generation-inference":
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+        if "https://api-inference.huggingface.co/models" in completion_url:
+            inference_params = copy.deepcopy(optional_params)
+            inference_params.pop("details")
+            data = {
+                "inputs": prompt,
+                "parameters": inference_params,
+                "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
+            }
+        else:
+            data = {
+                "inputs": prompt,
+                "parameters": optional_params,
+                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+            }
+        input_text = prompt
     ## LOGGING
     logging_obj.pre_call(
-            input=prompt,
+            input=input_text,
             api_key=api_key,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "task": task},
         )
     ## COMPLETION CALL
     if "stream" in optional_params and optional_params["stream"] == True:
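
To make the conversational branch above concrete, here is the same message-to-payload mapping run standalone on a made-up chat (the example messages and parameter values are assumptions, not part of the commit):

import copy

messages = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke"},
]
optional_params = {"details": True, "max_new_tokens": 50}

inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")  # the conversational endpoint does not take "details"

# Latest user turn becomes "text"; earlier user turns become "past_user_inputs";
# assistant/system turns become "generated_responses".
past_user_inputs, generated_responses, text = [], [], ""
for message in messages:
    if message["role"] == "user":
        if text != "":
            past_user_inputs.append(text)
        text = message["content"]
    elif message["role"] in ("assistant", "system"):
        generated_responses.append(message["content"])

data = {
    "inputs": {
        "text": text,                                # "Tell me a joke"
        "past_user_inputs": past_user_inputs,        # ["Hi there"]
        "generated_responses": generated_responses,  # ["Hello! How can I help?"]
    },
    "parameters": inference_params,                  # {"max_new_tokens": 50}
}
print(data)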

@@ -98,10 +125,10 @@ def completion(
         )
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=input_text,
             api_key=api_key,
             original_response=response.text,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "task": task},
         )
         ## RESPONSE OBJECT
         try:

@@ -119,19 +146,23 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            model_response["choices"][0]["message"][
-                "content"
-            ] = completion_response[0]["generated_text"]
-            ## GETTING LOGPROBS
-            if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                sum_logprob = 0
-                for token in completion_response[0]["details"]["tokens"]:
-                    sum_logprob += token["logprob"]
-                model_response["choices"][0]["message"]["logprobs"] = sum_logprob
+            if task == "conversational":
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response["generated_text"]
+            elif task == "text-generation-inference":
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
         ## CALCULATING USAGE
         prompt_tokens = len(
-            encoding.encode(prompt)
+            encoding.encode(input_text)
         ) ##[TODO] use the llama2 tokenizer here
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"]["content"])
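
As a sketch of what the new branching above consumes: a conversational call returns a single JSON object carrying "generated_text", while a text-generation-inference call returns a list whose first element carries "generated_text" and, when "details" was requested, per-token logprobs. The example payloads below are assumptions shaped like Hugging Face Inference API responses, not output captured from this commit:

conversational_response = {"generated_text": "Hello! How can I help?"}
tgi_response = [{
    "generated_text": "def add(a, b):\n    return a + b",
    "details": {"tokens": [{"logprob": -0.12}, {"logprob": -0.34}]},
}]

def extract(task, completion_response):
    # Mirrors the task-specific parsing added in the hunk above.
    if task == "conversational":
        return completion_response["generated_text"], None
    content = completion_response[0]["generated_text"]
    sum_logprob = None
    details = completion_response[0].get("details", {})
    if "tokens" in details:
        sum_logprob = sum(tok["logprob"] for tok in details["tokens"])
    return content, sum_logprob

print(extract("conversational", conversational_response))
print(extract("text-generation-inference", tgi_response))   # summed logprob ~= -0.46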

Top-level completion entrypoint (def completion):

@@ -109,8 +109,8 @@ def completion(
     use_client=False,
     id=None, # this is an optional param to tag individual completion calls
     # model specific optional params
-    # used by text-bison only
-    top_k=40,
+    top_k=40,# used by text-bison only
+    task: Optional[str]="text-generation-inference", # used by huggingface inference endpoints
     request_timeout=0, # unused var for old version of OpenAI API
     fallbacks=[],
     caching = False,

@@ -154,6 +154,7 @@ def completion(
             model=model,
             custom_llm_provider=custom_llm_provider,
             top_k=top_k,
+            task=task
         )
         # For logging - save the values of the litellm-specific params passed in
         litellm_params = get_litellm_params(

Tests (test_completion_claude_stream, commented-out HuggingFace test):

@@ -119,7 +119,8 @@ def test_completion_claude_stream():
 #     try:
 #         user_message = "write some code to find the sum of two numbers"
 #         messages = [{ "content": user_message,"role": "user"}]
-#         response = completion(model="stabilityai/stablecode-instruct-alpha-3b", messages=messages, custom_llm_provider="huggingface", logger_fn=logger_fn)
+#         api_base = "https://wyh9bqfgj2r1klv5.us-east-1.aws.endpoints.huggingface.cloud"
+#         response = completion(model="facebook/blenderbot-400M-distill", messages=messages, custom_llm_provider="huggingface", task="conversational", api_base=api_base, logger_fn=logger_fn)
 #         # Add any assertions here to check the response
 #         print(response)
 #     except Exception as e:
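
Putting the pieces together, the commented-out test above suggests the intended call shape. A minimal usage sketch (the model name and endpoint mirror the test; a valid Hugging Face token and a reachable endpoint are assumed):

import litellm

response = litellm.completion(
    model="facebook/blenderbot-400M-distill",
    messages=[{"role": "user", "content": "Hi there"}],
    custom_llm_provider="huggingface",
    task="conversational",   # new param; defaults to "text-generation-inference"
    api_base="https://wyh9bqfgj2r1klv5.us-east-1.aws.endpoints.huggingface.cloud",
)
print(response)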

get_optional_params (parameter mapping):

@@ -788,6 +788,7 @@ def get_optional_params( # use the openai defaults
     model=None,
     custom_llm_provider="",
     top_k=40,
+    task=None
 ):
     optional_params = {}
     if model in litellm.anthropic_models:

@@ -882,6 +883,7 @@ def get_optional_params( # use the openai defaults
         if presence_penalty != 0:
             optional_params["repetition_penalty"] = presence_penalty
         optional_params["details"] = True
+        optional_params["task"] = task
     elif custom_llm_provider == "sagemaker":
         if "llama-2" in model:
             # llama-2 models on sagemaker support the following args
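
To trace the new plumbing, here is a standalone sketch of the two lines added to the huggingface branch of get_optional_params, using an assumed presence_penalty value (the surrounding mapping logic is not shown in the hunk):

optional_params = {}
presence_penalty = 0.2           # assumed caller-supplied value
task = "conversational"          # passed down from completion() via the new kwarg

if presence_penalty != 0:
    optional_params["repetition_penalty"] = presence_penalty
optional_params["details"] = True
optional_params["task"] = task

# -> {"repetition_penalty": 0.2, "details": True, "task": "conversational"}
# The HuggingFace handler later pops "task" (and drops "details" for the
# conversational endpoint) before building the request body.
print(optional_params)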

pyproject.toml:

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.618"
+version = "0.1.619"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"