refactor(huggingface_restapi.py): moving async completion + streaming to real async calls

Krrish Dholakia 2023-11-15 15:14:13 -08:00
parent 77394e7987
commit 1a705bfbcb
5 changed files with 464 additions and 365 deletions
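
For orientation, a minimal usage sketch of the path this commit changes (assuming a Hugging Face API key is configured in the environment, e.g. HUGGINGFACE_API_KEY, and reusing the model name from the updated tests below):

    import asyncio
    import litellm

    async def main():
        messages = [{"role": "user", "content": "Hello, how are you?"}]

        # Non-streaming: with this commit the call is awaited on an
        # httpx.AsyncClient POST rather than the synchronous requests path.
        response = await litellm.acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages
        )
        print(response)

        # Streaming: the returned object is an async iterator backed by
        # httpx's aiter_lines(), wrapped in CustomStreamWrapper.
        stream = await litellm.acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages, stream=True
        )
        async for chunk in stream:
            print(chunk)

    asyncio.run(main())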

View file

@@ -19,7 +19,7 @@ telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = False
 retry = True
-request_timeout: float = 6000
+request_timeout: Optional[float] = None
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None

View file

@@ -3,6 +3,7 @@ import os, copy, types
 import json
 from enum import Enum
 import httpx, requests
+from .base import BaseLLM
 import time
 import litellm
 from typing import Callable, Dict, List, Any
@@ -67,19 +68,6 @@ class HuggingfaceConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}

-def validate_environment(api_key, headers):
-    default_headers = {
-        "content-type": "application/json",
-    }
-    if api_key and headers is None:
-        default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
-        headers = default_headers
-    elif headers:
-        headers=headers
-    else:
-        headers = default_headers
-    return headers
-
 def output_parser(generated_text: str):
     """
     Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -94,8 +82,6 @@ def output_parser(generated_text: str):
             generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
     return generated_text

-
-
 tgi_models_cache = None
 conv_models_cache = None
 def read_tgi_conv_models():
def read_tgi_conv_models(): def read_tgi_conv_models():
@@ -144,7 +130,106 @@ def get_hf_task_for_model(model):
     else:
         return "text-generation-inference" # default to tgi

-def completion(
+class Huggingface(BaseLLM):
+    _client_session: Optional[httpx.Client] = None
+    _aclient_session: Optional[httpx.AsyncClient] = None
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def validate_environment(self, api_key, headers):
+        default_headers = {
+            "content-type": "application/json",
+        }
+        if api_key and headers is None:
+            default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
+            headers = default_headers
+        elif headers:
+            headers=headers
+        else:
+            headers = default_headers
+        return headers
+
+    def convert_to_model_response_object(self,
+                                         completion_response,
+                                         model_response,
+                                         task,
+                                         optional_params,
+                                         encoding,
+                                         input_text,
+                                         model):
+        if task == "conversational":
+            if len(completion_response["generated_text"]) > 0: # type: ignore
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response["generated_text"] # type: ignore
+        elif task == "text-generation-inference":
+            if len(completion_response[0]["generated_text"]) > 0:
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = output_parser(completion_response[0]["generated_text"])
+            ## GETTING LOGPROBS + FINISH REASON
+            if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+                sum_logprob = 0
+                for token in completion_response[0]["details"]["tokens"]:
+                    if token["logprob"] != None:
+                        sum_logprob += token["logprob"]
+                model_response["choices"][0]["message"]._logprob = sum_logprob
+            if "best_of" in optional_params and optional_params["best_of"] > 1:
+                if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
+                    choices_list = []
+                    for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
+                        sum_logprob = 0
+                        for token in item["tokens"]:
+                            if token["logprob"] != None:
+                                sum_logprob += token["logprob"]
+                        if len(item["generated_text"]) > 0:
+                            message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
+                        else:
+                            message_obj = Message(content=None)
+                        choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
+                        choices_list.append(choice_obj)
+                    model_response["choices"].extend(choices_list)
+        else:
+            if len(completion_response[0]["generated_text"]) > 0:
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = output_parser(completion_response[0]["generated_text"])
+        ## CALCULATING USAGE
+        prompt_tokens = 0
+        try:
+            prompt_tokens = len(
+                encoding.encode(input_text)
+            ) ##[TODO] use the llama2 tokenizer here
+        except:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
+        output_text = model_response["choices"][0]["message"].get("content", "")
+        if output_text is not None and len(output_text) > 0:
+            completion_tokens = 0
+            try:
+                completion_tokens = len(
+                    encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+                ) ##[TODO] use the llama2 tokenizer here
+            except:
+                # this should remain non blocking we should not block a response returning if calculating usage fails
+                pass
+        else:
+            completion_tokens = 0
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
+        model_response._hidden_params["original_response"] = completion_response
+        return model_response
+
+    def completion(self,
         model: str,
         messages: list,
         api_base: Optional[str],
@@ -155,13 +240,15 @@ def completion(
         api_key,
         logging_obj,
         custom_prompt_dict={},
+        acompletion: bool = False,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
     ):
+        super().completion()
         exception_mapping_worked = False
         try:
-            headers = validate_environment(api_key, headers)
+            headers = self.validate_environment(api_key, headers)
             task = get_hf_task_for_model(model)
             print_verbose(f"{model}, {task}")
             completion_url = ""
@@ -255,9 +342,17 @@ def completion(
             logging_obj.pre_call(
                 input=input_text,
                 api_key=api_key,
-                additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url},
+                additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
             )
             ## COMPLETION CALL
+            if acompletion is True:
+                ### ASYNC STREAMING
+                if optional_params.get("stream", False):
+                    return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model)
+                else:
+                    ### ASYNC COMPLETION
+                    return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params)
+            ### SYNC STREAMING
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = requests.post(
                     completion_url,
@@ -266,6 +361,7 @@ def completion(
                     stream=optional_params["stream"]
                 )
                 return response.iter_lines()
+            ### SYNC COMPLETION
             else:
                 response = requests.post(
                     completion_url,
@@ -273,7 +369,6 @@ def completion(
                     data=json.dumps(data)
                 )
                 ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
-
                 is_streamed = False
                 if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
@@ -317,78 +412,16 @@ def completion(
                     message=completion_response["error"],
                     status_code=response.status_code,
                 )
-            else:
-                if task == "conversational":
-                    if len(completion_response["generated_text"]) > 0: # type: ignore
-                        model_response["choices"][0]["message"][
-                            "content"
-                        ] = completion_response["generated_text"] # type: ignore
-                elif task == "text-generation-inference":
-                    if len(completion_response[0]["generated_text"]) > 0:
-                        model_response["choices"][0]["message"][
-                            "content"
-                        ] = output_parser(completion_response[0]["generated_text"])
-                    ## GETTING LOGPROBS + FINISH REASON
-                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                        model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
-                        sum_logprob = 0
-                        for token in completion_response[0]["details"]["tokens"]:
-                            if token["logprob"] != None:
-                                sum_logprob += token["logprob"]
-                        model_response["choices"][0]["message"]._logprob = sum_logprob
-                    if "best_of" in optional_params and optional_params["best_of"] > 1:
-                        if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
-                            choices_list = []
-                            for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
-                                sum_logprob = 0
-                                for token in item["tokens"]:
-                                    if token["logprob"] != None:
-                                        sum_logprob += token["logprob"]
-                                if len(item["generated_text"]) > 0:
-                                    message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
-                                else:
-                                    message_obj = Message(content=None)
-                                choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
-                                choices_list.append(choice_obj)
-                            model_response["choices"].extend(choices_list)
-                else:
-                    if len(completion_response[0]["generated_text"]) > 0:
-                        model_response["choices"][0]["message"][
-                            "content"
-                        ] = output_parser(completion_response[0]["generated_text"])
-            ## CALCULATING USAGE
-            prompt_tokens = 0
-            try:
-                prompt_tokens = len(
-                    encoding.encode(input_text)
-                ) ##[TODO] use the llama2 tokenizer here
-            except:
-                # this should remain non blocking we should not block a response returning if calculating usage fails
-                pass
-            print_verbose(f'output: {model_response["choices"][0]["message"]}')
-            output_text = model_response["choices"][0]["message"].get("content", "")
-            if output_text is not None and len(output_text) > 0:
-                completion_tokens = 0
-                try:
-                    completion_tokens = len(
-                        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-                    ) ##[TODO] use the llama2 tokenizer here
-                except:
-                    # this should remain non blocking we should not block a response returning if calculating usage fails
-                    pass
-            else:
-                completion_tokens = 0
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens
-            )
-            model_response.usage = usage
-            model_response._hidden_params["original_response"] = completion_response
-            return model_response
+            return self.convert_to_model_response_object(
+                completion_response=completion_response,
+                model_response=model_response,
+                task=task,
+                optional_params=optional_params,
+                encoding=encoding,
+                input_text=input_text,
+                model=model
+            )
         except HuggingfaceError as e:
             exception_mapping_worked = True
             raise e
@@ -399,8 +432,65 @@ def completion(
             import traceback
             raise HuggingfaceError(status_code=500, message=traceback.format_exc())

-def embedding(
+    async def acompletion(self,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          task: str,
+                          encoding: Any,
+                          input_text: str,
+                          model: str,
+                          optional_params: dict):
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
+        try:
+            response = await client.post(url=api_base, json=data, headers=headers)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+            ## RESPONSE OBJECT
+            return self.convert_to_model_response_object(completion_response=response_json,
+                                                         model_response=model_response,
+                                                         task=task,
+                                                         encoding=encoding,
+                                                         input_text=input_text,
+                                                         model=model,
+                                                         optional_params=optional_params)
+        except Exception as e:
+            if isinstance(e, httpx.TimeoutException):
+                raise HuggingfaceError(status_code=500, message="Request Timeout Error")
+            elif response and hasattr(response, "text"):
+                raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                raise HuggingfaceError(status_code=500, message=f"{str(e)}")
+
+    async def async_streaming(self,
+                              logging_obj,
+                              api_base: str,
+                              data: dict,
+                              headers: dict,
+                              model_response: ModelResponse,
+                              model: str):
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
+        async with client.stream(
+            url=f"{api_base}",
+            json=data,
+            headers=headers,
+            method="POST"
+        ) as response:
+            if response.status_code != 200:
+                raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
+
+            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+
+    def embedding(self,
         model: str,
         input: list,
         api_key: Optional[str] = None,
@@ -409,7 +499,8 @@ def embedding(
         model_response=None,
         encoding=None,
     ):
-    headers = validate_environment(api_key, headers=None)
+        super().embedding()
+        headers = self.validate_environment(api_key, headers=None)
         # print_verbose(f"{model}, {task}")
         embed_url = ""
         if "https" in model:

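The heart of the refactor above is replacing blocking requests calls with httpx.AsyncClient on the async path. A standalone sketch of that pattern (endpoint URL, payload, and token here are placeholders, not the provider's defaults):

    import asyncio
    import httpx

    async def demo():
        url = "https://example.com/generate"           # placeholder endpoint
        headers = {"Authorization": "Bearer <token>"}  # placeholder token
        data = {"inputs": "Hello"}

        async with httpx.AsyncClient() as client:
            # One-shot request, as in Huggingface.acompletion above
            response = await client.post(url=url, json=data, headers=headers)
            print(response.status_code, response.json())

            # Streamed request, as in Huggingface.async_streaming above:
            # lines arrive without blocking the event loop.
            async with client.stream(method="POST", url=url, json=data, headers=headers) as resp:
                async for line in resp.aiter_lines():
                    print(line)

    asyncio.run(demo())

Unlike this sketch, the class above reuses a shared _aclient_session rather than opening a new client per call.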
View file

@@ -53,6 +53,7 @@ from .llms import (
     maritalk)
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
+from .llms.huggingface_restapi import Huggingface
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -77,6 +78,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
+huggingface = Huggingface()
 ####### COMPLETION ENDPOINTS ################

 class LiteLLM:
@@ -165,7 +167,8 @@ async def acompletion(*args, **kwargs):
         if (custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
             or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            or custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
             if kwargs.get("stream", False):
                 response = completion(*args, **kwargs)
             else:
@@ -862,7 +865,7 @@ def completion(
                 custom_prompt_dict
                 or litellm.custom_prompt_dict
             )
-            model_response = huggingface_restapi.completion(
+            model_response = huggingface.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base, # type: ignore
@@ -874,10 +877,11 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=huggingface_key,
+                acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if "stream" in optional_params and optional_params["stream"] == True and acompletion is False:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="huggingface", logging_obj=logging

View file

@@ -25,11 +25,12 @@ def test_sync_response():

 def test_async_response():
     import asyncio
+    litellm.set_verbose = True
     async def test_get_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="command-nightly", messages=messages)
+            response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
@@ -44,7 +45,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))

             import inspect
@@ -67,15 +68,16 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())

-test_get_response_streaming()
+# test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio
+    litellm.set_verbose = True
     async def test_async_call():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages, stream=True)
             print(type(response))

             import inspect
@@ -98,4 +100,4 @@ def test_get_response_non_openai_streaming():
         return response

     asyncio.run(test_async_call())
-# test_get_response_non_openai_streaming()
+test_get_response_non_openai_streaming()

View file

@@ -511,6 +511,8 @@ class Logging:
             masked_headers = {k: v[:-40] + '*' * 40 if len(v) > 40 else v for k, v in headers.items()}
             formatted_headers = " ".join([f"-H '{k}: {v}'" for k, v in masked_headers.items()])

+            print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}")
+
             curl_command = "\n\nPOST Request Sent from LiteLLM:\n"
             curl_command += "curl -X POST \\\n"
             curl_command += f"{api_base} \\\n"
@@ -4313,7 +4315,6 @@ class CustomStreamWrapper:
     def handle_huggingface_chunk(self, chunk):
         try:
-            chunk = chunk.decode("utf-8")
             text = ""
             is_finished = False
             finish_reason = ""
@@ -4770,7 +4771,8 @@ class CustomStreamWrapper:
             if (self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
-                or self.custom_llm_provider == "text-completion-openai"):
+                or self.custom_llm_provider == "text-completion-openai"
+                or self.custom_llm_provider == "huggingface"):
                 async for chunk in self.completion_stream:
                     if chunk == "None" or chunk is None:
                         raise Exception