refactor(huggingface_restapi.py): moving async completion + streaming to real async calls

Krrish Dholakia 2023-11-15 15:14:13 -08:00
parent 77394e7987
commit 1a705bfbcb
5 changed files with 464 additions and 365 deletions
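For context, a minimal usage sketch of what this change enables: calling a Hugging Face model through litellm.acompletion so the request runs natively on the event loop instead of going through the synchronous code path. The model id and prompt below are illustrative, not taken from this commit, and a Hugging Face token (e.g. HUGGINGFACE_API_KEY) is expected in the environment.

import asyncio
import litellm

async def main():
    # example model id (assumed); any Hugging Face inference endpoint is called the same way
    response = await litellm.acompletion(
        model="huggingface/bigcode/starcoder",
        messages=[{"role": "user", "content": "def fib(n):"}],
    )
    print(response)

asyncio.run(main())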


@@ -53,6 +53,7 @@ from .llms import (
     maritalk)
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
+from .llms.huggingface_restapi import Huggingface
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -77,6 +78,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
+huggingface = Huggingface()
 ####### COMPLETION ENDPOINTS ################
 class LiteLLM:
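The hunk above replaces direct calls into the huggingface_restapi module with a Huggingface handler object instantiated once at import time, mirroring the existing OpenAIChatCompletion / AzureChatCompletion handlers. A rough sketch of that shape, with assumed method names (the real class lives in llms/huggingface_restapi.py and differs in detail):

class HandlerSketch:
    # illustrative only: completion() returns a finished response when called
    # synchronously, and a coroutine when acompletion=True so the async
    # entrypoint can await it
    def completion(self, model, messages, acompletion=False, **kwargs):
        if acompletion:
            return self.acompletion(model=model, messages=messages, **kwargs)
        return self._sync_request(model=model, messages=messages, **kwargs)

    async def acompletion(self, model, messages, **kwargs):
        ...  # the real async HTTP call to the inference endpoint goes here

    def _sync_request(self, model, messages, **kwargs):
        ...  # existing blocking request path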
@@ -165,7 +167,8 @@ async def acompletion(*args, **kwargs):
     if (custom_llm_provider == "openai"
         or custom_llm_provider == "azure"
         or custom_llm_provider == "custom_openai"
-        or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+        or custom_llm_provider == "text-completion-openai"
+        or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
         if kwargs.get("stream", False):
             response = completion(*args, **kwargs)
         else:
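This hunk adds "huggingface" to the set of providers that acompletion treats as natively async. A simplified sketch of that dispatch pattern (not the actual litellm source): providers in the set have their coroutine awaited directly, while everything else still runs the blocking completion() in a worker thread.

import asyncio
from functools import partial

NATIVE_ASYNC = {"openai", "azure", "custom_openai", "text-completion-openai", "huggingface"}

async def acompletion_sketch(completion_fn, custom_llm_provider, **kwargs):
    if custom_llm_provider in NATIVE_ASYNC and not kwargs.get("stream", False):
        # completion_fn is assumed to return a coroutine when acompletion=True is threaded through
        return await completion_fn(acompletion=True, **kwargs)
    # fallback: run the synchronous call in a thread so the event loop is not blocked
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(completion_fn, **kwargs))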
@@ -862,7 +865,7 @@ def completion(
             custom_prompt_dict
             or litellm.custom_prompt_dict
         )
-        model_response = huggingface_restapi.completion(
+        model_response = huggingface.completion(
            model=model,
            messages=messages,
            api_base=api_base, # type: ignore
@@ -874,10 +877,11 @@ def completion(
            logger_fn=logger_fn,
            encoding=encoding,
            api_key=huggingface_key,
+           acompletion=acompletion,
            logging_obj=logging,
            custom_prompt_dict=custom_prompt_dict
        )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if "stream" in optional_params and optional_params["stream"] == True and acompletion is False:
            # don't try to access stream object,
            response = CustomStreamWrapper(
                model_response, model, custom_llm_provider="huggingface", logging_obj=logging
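Two things change in this last hunk: the acompletion flag is now passed through to the handler, and the synchronous CustomStreamWrapper is only applied when acompletion is False, since the async path returns its own awaitable (or async iterator, when streaming) instead. For illustration, one way a handler can make the underlying HTTP request genuinely async is with an async client; the function and parameter names below are assumptions, not the commit's actual internals.

import httpx

async def hf_async_request(api_base: str, api_key: str, prompt: str, params: dict) -> dict:
    # assumed helper: POST the prompt to a Hugging Face text-generation endpoint
    # without blocking the event loop
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"inputs": prompt, "parameters": params, "stream": False}
    async with httpx.AsyncClient(timeout=600.0) as client:
        response = await client.post(api_base, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()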