diff --git a/litellm/__init__.py b/litellm/__init__.py
index fe5d4803d..747db753f 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -19,7 +19,7 @@ telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = False
 retry = True
-request_timeout: float = 6000
+request_timeout: Optional[float] = None
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 251106828..f7379e97e 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -3,6 +3,7 @@ import os, copy, types
 import json
 from enum import Enum
 import httpx, requests
+from .base import BaseLLM
 import time
 import litellm
 from typing import Callable, Dict, List, Any
@@ -67,19 +68,6 @@ class HuggingfaceConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}
 
-def validate_environment(api_key, headers):
-    default_headers = {
-        "content-type": "application/json",
-    }
-    if api_key and headers is None:
-        default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
-        headers = default_headers
-    elif headers:
-        headers=headers
-    else:
-        headers = default_headers
-    return headers
-
 def output_parser(generated_text: str):
     """
     Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
@@ -94,8 +82,6 @@ def output_parser(generated_text: str):
         generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
     return generated_text
 
-
-
 tgi_models_cache = None
 conv_models_cache = None
 def read_tgi_conv_models():
@@ -144,365 +130,470 @@ def get_hf_task_for_model(model):
     else:
         return "text-generation-inference" # default to tgi
 
-def completion(
-    model: str,
-    messages: list,
-    api_base: Optional[str],
-    headers: Optional[dict],
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    custom_prompt_dict={},
-    optional_params=None,
-    litellm_params=None,
-    logger_fn=None,
-):
-    exception_mapping_worked = False
-    try:
-        headers = validate_environment(api_key, headers)
-        task = get_hf_task_for_model(model)
-        print_verbose(f"{model}, {task}")
-        completion_url = ""
-        input_text = None
-        if "https" in model:
-            completion_url = model
-        elif api_base:
-            completion_url = api_base
-        elif "HF_API_BASE" in os.environ:
-            completion_url = os.getenv("HF_API_BASE", "")
-        elif "HUGGINGFACE_API_BASE" in os.environ:
-            completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
-        else:
-            completion_url = f"https://api-inference.huggingface.co/models/{model}"
+class Huggingface(BaseLLM):
+    _client_session: Optional[httpx.Client] = None
+    _aclient_session: Optional[httpx.AsyncClient] = None
-        ## Load Config
-        config=litellm.HuggingfaceConfig.get_config()
-        for k, v in config.items():
-            if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
+    def __init__(self) -> None:
+        super().__init__()
-        ### MAP INPUT PARAMS
-        if task == "conversational":
-            inference_params = copy.deepcopy(optional_params)
-            inference_params.pop("details")
-            inference_params.pop("return_full_text")
-            past_user_inputs = []
-            generated_responses = []
-            text = ""
-            for message in messages:
-                if message["role"] == "user":
-                    if text != "":
-                        past_user_inputs.append(text)
-                    text = message["content"]
-                elif message["role"] == "assistant" or message["role"] == "system":
-                    generated_responses.append(message["content"])
-            data = {
-                "inputs": {
-                    "text": text,
-                    "past_user_inputs": past_user_inputs,
-                    "generated_responses": generated_responses
-                },
-                "parameters": inference_params
-            }
-            input_text = "".join(message["content"] for message in messages)
-        elif task == "text-generation-inference":
-            # always send "details" and "return_full_text" as params
-            if model in custom_prompt_dict:
-                # check if the model has a registered custom prompt
-                model_prompt_details = custom_prompt_dict[model]
-                prompt = custom_prompt(
-                    role_dict=model_prompt_details.get("roles", None),
-                    initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-                    final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                    messages=messages
-                )
-            else:
-                prompt = prompt_factory(model=model, messages=messages)
-            data = {
-                "inputs": prompt,
-                "parameters": optional_params,
-                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
-            }
-            input_text = prompt
-        else:
-            # Non TGI and Conversational llms
-            # We need this branch, it removes 'details' and 'return_full_text' from params
-            if model in custom_prompt_dict:
-                # check if the model has a registered custom prompt
-                model_prompt_details = custom_prompt_dict[model]
-                prompt = custom_prompt(
-                    role_dict=model_prompt_details.get("roles", {}),
-                    initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
-                    final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
-                    bos_token=model_prompt_details.get("bos_token", ""),
-                    eos_token=model_prompt_details.get("eos_token", ""),
-                    messages=messages,
-                )
-            else:
-                prompt = prompt_factory(model=model, messages=messages)
-            inference_params = copy.deepcopy(optional_params)
-            inference_params.pop("details")
-            inference_params.pop("return_full_text")
-            data = {
-                "inputs": prompt,
-                "parameters": inference_params,
-                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
-            }
-            input_text = prompt
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input_text,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url},
-        )
-        ## COMPLETION CALL
-        if "stream" in optional_params and optional_params["stream"] == True:
-            response = requests.post(
-                completion_url,
-                headers=headers,
-                data=json.dumps(data),
-                stream=optional_params["stream"]
-            )
-            return response.iter_lines()
-        else:
-            response = requests.post(
-                completion_url,
-                headers=headers,
-                data=json.dumps(data)
-            )
+    def validate_environment(self, api_key, headers):
+        default_headers = {
+            "content-type": "application/json",
+        }
+        if api_key and headers is None:
+            default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
+            headers = default_headers
+        elif headers:
+            headers=headers
+        else:
+            headers = default_headers
+        return headers
-
-            ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
-            is_streamed = False
-            if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
-                is_streamed = True
-
-            # iterate over the complete streamed response, and return the final answer
-            if is_streamed:
-                streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
-                content = ""
-                for chunk in streamed_response:
-                    content += chunk["choices"][0]["delta"]["content"]
-                completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
-                ## LOGGING
-                logging_obj.post_call(
-                    input=input_text,
-                    api_key=api_key,
-                    original_response=completion_response,
-                    additional_args={"complete_input_dict": data, "task": task},
-                )
-            else:
-                ## LOGGING
-                logging_obj.post_call(
-                    input=input_text,
-                    api_key=api_key,
-                    original_response=response.text,
-                    additional_args={"complete_input_dict": data, "task": task},
-                )
-                ## RESPONSE OBJECT
-                try:
-                    completion_response = response.json()
-                except:
-                    import traceback
-                    raise HuggingfaceError(
-                        message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
-                    )
-                print_verbose(f"response: {completion_response}")
-                if isinstance(completion_response, dict) and "error" in completion_response:
-                    print_verbose(f"completion error: {completion_response['error']}")
-                    print_verbose(f"response.status_code: {response.status_code}")
-                    raise HuggingfaceError(
-                        message=completion_response["error"],
-                        status_code=response.status_code,
-                    )
-                else:
-                    if task == "conversational":
-                        if len(completion_response["generated_text"]) > 0: # type: ignore
-                            model_response["choices"][0]["message"][
-                                "content"
-                            ] = completion_response["generated_text"] # type: ignore
-                    elif task == "text-generation-inference":
-                        if len(completion_response[0]["generated_text"]) > 0:
-                            model_response["choices"][0]["message"][
-                                "content"
-                            ] = output_parser(completion_response[0]["generated_text"])
-                        ## GETTING LOGPROBS + FINISH REASON
-                        if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
-                            model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+    def convert_to_model_response_object(self,
+                                         completion_response,
+                                         model_response,
+                                         task,
+                                         optional_params,
+                                         encoding,
+                                         input_text,
+                                         model):
+        if task == "conversational":
+            if len(completion_response["generated_text"]) > 0: # type: ignore
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = completion_response["generated_text"] # type: ignore
+        elif task == "text-generation-inference":
+            if len(completion_response[0]["generated_text"]) > 0:
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = output_parser(completion_response[0]["generated_text"])
+            ## GETTING LOGPROBS + FINISH REASON
+            if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
+                sum_logprob = 0
+                for token in completion_response[0]["details"]["tokens"]:
+                    if token["logprob"] != None:
+                        sum_logprob += token["logprob"]
+                model_response["choices"][0]["message"]._logprob = sum_logprob
+            if "best_of" in optional_params and optional_params["best_of"] > 1:
+                if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
+                    choices_list = []
+                    for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
                         sum_logprob = 0
-                        for token in completion_response[0]["details"]["tokens"]:
+                        for token in item["tokens"]:
                             if token["logprob"] != None:
                                 sum_logprob += token["logprob"]
-                        model_response["choices"][0]["message"]._logprob = sum_logprob
-                    if "best_of" in optional_params and optional_params["best_of"] > 1:
-                        if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
-                            choices_list = []
-                            for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
-                                sum_logprob = 0
-                                for token in item["tokens"]:
-                                    if token["logprob"] != None:
-                                        sum_logprob += token["logprob"]
-                                if len(item["generated_text"]) > 0:
-                                    message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
-                                else:
-                                    message_obj = Message(content=None)
-                                choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
-                                choices_list.append(choice_obj)
-                            model_response["choices"].extend(choices_list)
-                    else:
-                        if len(completion_response[0]["generated_text"]) > 0:
-                            model_response["choices"][0]["message"][
-                                "content"
-                            ] = output_parser(completion_response[0]["generated_text"])
-            ## CALCULATING USAGE
-            prompt_tokens = 0
+                        if len(item["generated_text"]) > 0:
+                            message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
+                        else:
+                            message_obj = Message(content=None)
+                        choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
+                        choices_list.append(choice_obj)
+                    model_response["choices"].extend(choices_list)
+        else:
+            if len(completion_response[0]["generated_text"]) > 0:
+                model_response["choices"][0]["message"][
+                    "content"
+                ] = output_parser(completion_response[0]["generated_text"])
+        ## CALCULATING USAGE
+        prompt_tokens = 0
+        try:
+            prompt_tokens = len(
+                encoding.encode(input_text)
+            ) ##[TODO] use the llama2 tokenizer here
+        except:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
+        output_text = model_response["choices"][0]["message"].get("content", "")
+        if output_text is not None and len(output_text) > 0:
+            completion_tokens = 0
             try:
-                prompt_tokens = len(
-                    encoding.encode(input_text)
+                completion_tokens = len(
+                    encoding.encode(model_response["choices"][0]["message"].get("content", ""))
                 ) ##[TODO] use the llama2 tokenizer here
             except:
                 # this should remain non blocking we should not block a response returning if calculating usage fails
                 pass
-            print_verbose(f'output: {model_response["choices"][0]["message"]}')
-            output_text = model_response["choices"][0]["message"].get("content", "")
-            if output_text is not None and len(output_text) > 0:
-                completion_tokens = 0
-                try:
-                    completion_tokens = len(
-                        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-                    ) ##[TODO] use the llama2 tokenizer here
-                except:
-                    # this should remain non blocking we should not block a response returning if calculating usage fails
-                    pass
-            else:
-                completion_tokens = 0
-
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            usage = Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens
-            )
-            model_response.usage = usage
-            model_response._hidden_params["original_response"] = completion_response
-            return model_response
-    except HuggingfaceError as e:
-        exception_mapping_worked = True
-        raise e
-    except Exception as e:
-        if exception_mapping_worked:
-            raise e
         else:
-            import traceback
-            raise HuggingfaceError(status_code=500, message=traceback.format_exc())
+            completion_tokens = 0
-
-def embedding(
-    model: str,
-    input: list,
-    api_key: Optional[str] = None,
-    api_base: Optional[str] = None,
-    logging_obj=None,
-    model_response=None,
-    encoding=None,
-):
-    headers = validate_environment(api_key, headers=None)
-    # print_verbose(f"{model}, {task}")
-    embed_url = ""
-    if "https" in model:
-        embed_url = model
-    elif api_base:
-        embed_url = api_base
-    elif "HF_API_BASE" in os.environ:
-        embed_url = os.getenv("HF_API_BASE", "")
-    elif "HUGGINGFACE_API_BASE" in os.environ:
-        embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
-    else:
-        embed_url = f"https://api-inference.huggingface.co/models/{model}"
-
-    if "sentence-transformers" in model:
-        if len(input) == 0:
-            raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
-        data = {
-            "inputs": {
-                "source_sentence": input[0],
-                "sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
-            }
-        }
-    else:
-        data = {
-            "inputs": input # type: ignore
-        }
-
-    ## LOGGING
-    logging_obj.pre_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
         )
-    ## COMPLETION CALL
-    response = requests.post(
-        embed_url, headers=headers, data=json.dumps(data)
-    )
+        model_response.usage = usage
+        model_response._hidden_params["original_response"] = completion_response
+        return model_response
-
-    ## LOGGING
-    logging_obj.post_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-            original_response=response,
-        )
+    def completion(self,
+                   model: str,
+                   messages: list,
+                   api_base: Optional[str],
+                   headers: Optional[dict],
+                   model_response: ModelResponse,
+                   print_verbose: Callable,
+                   encoding,
+                   api_key,
+                   logging_obj,
+                   custom_prompt_dict={},
+                   acompletion: bool = False,
+                   optional_params=None,
+                   litellm_params=None,
+                   logger_fn=None,
+    ):
+        super().completion()
+        exception_mapping_worked = False
+        try:
+            headers = self.validate_environment(api_key, headers)
+            task = get_hf_task_for_model(model)
+            print_verbose(f"{model}, {task}")
+            completion_url = ""
+            input_text = None
+            if "https" in model:
+                completion_url = model
+            elif api_base:
+                completion_url = api_base
+            elif "HF_API_BASE" in os.environ:
+                completion_url = os.getenv("HF_API_BASE", "")
+            elif "HUGGINGFACE_API_BASE" in os.environ:
+                completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
+            else:
+                completion_url = f"https://api-inference.huggingface.co/models/{model}"
+            ## Load Config
+            config=litellm.HuggingfaceConfig.get_config()
+            for k, v in config.items():
+                if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
-    embeddings = response.json()
-
-    if "error" in embeddings:
-        raise HuggingfaceError(status_code=500, message=embeddings['error'])
-
-    output_data = []
-    if "similarities" in embeddings:
-        for idx, embedding in embeddings["similarities"]:
-            output_data.append(
-                {
-                    "object": "embedding",
-                    "index": idx,
-                    "embedding": embedding # flatten list returned from hf
-                }
-            )
-    else:
-        for idx, embedding in enumerate(embeddings):
-            if isinstance(embedding, float):
-                output_data.append(
-                    {
-                        "object": "embedding",
-                        "index": idx,
-                        "embedding": embedding # flatten list returned from hf
-                    }
+            ### MAP INPUT PARAMS
+            if task == "conversational":
+                inference_params = copy.deepcopy(optional_params)
+                inference_params.pop("details")
+                inference_params.pop("return_full_text")
+                past_user_inputs = []
+                generated_responses = []
+                text = ""
+                for message in messages:
+                    if message["role"] == "user":
+                        if text != "":
+                            past_user_inputs.append(text)
+                        text = message["content"]
+                    elif message["role"] == "assistant" or message["role"] == "system":
+                        generated_responses.append(message["content"])
+                data = {
+                    "inputs": {
+                        "text": text,
+                        "past_user_inputs": past_user_inputs,
+                        "generated_responses": generated_responses
+                    },
+                    "parameters": inference_params
+                }
+                input_text = "".join(message["content"] for message in messages)
+            elif task == "text-generation-inference":
+                # always send "details" and "return_full_text" as params
+                if model in custom_prompt_dict:
+                    # check if the model has a registered custom prompt
+                    model_prompt_details = custom_prompt_dict[model]
+                    prompt = custom_prompt(
+                        role_dict=model_prompt_details.get("roles", None),
+                        initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                        final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                        messages=messages
+                    )
+                else:
+                    prompt = prompt_factory(model=model, messages=messages)
+                data = {
+                    "inputs": prompt,
+                    "parameters": optional_params,
+                    "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+                }
+                input_text = prompt
+            else:
+                # Non TGI and Conversational llms
+                # We need this branch, it removes 'details' and 'return_full_text' from params
+                if model in custom_prompt_dict:
+                    # check if the model has a registered custom prompt
+                    model_prompt_details = custom_prompt_dict[model]
+                    prompt = custom_prompt(
+                        role_dict=model_prompt_details.get("roles", {}),
+                        initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                        final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                        bos_token=model_prompt_details.get("bos_token", ""),
+                        eos_token=model_prompt_details.get("eos_token", ""),
+                        messages=messages,
+                    )
+                else:
+                    prompt = prompt_factory(model=model, messages=messages)
+                inference_params = copy.deepcopy(optional_params)
+                inference_params.pop("details")
+                inference_params.pop("return_full_text")
+                data = {
+                    "inputs": prompt,
+                    "parameters": inference_params,
+                    "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+                }
+                input_text = prompt
+            ## LOGGING
+            logging_obj.pre_call(
+                input=input_text,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
             )
+            ## COMPLETION CALL
+            if acompletion is True:
+                ### ASYNC STREAMING
+                if optional_params.get("stream", False):
+                    return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model)
+                else:
+                    ### ASYNC COMPLETION
+                    return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params)
+            ### SYNC STREAMING
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = requests.post(
+                    completion_url,
+                    headers=headers,
+                    data=json.dumps(data),
+                    stream=optional_params["stream"]
+                )
+                return response.iter_lines()
+            ### SYNC COMPLETION
+            else:
+                response = requests.post(
+                    completion_url,
+                    headers=headers,
+                    data=json.dumps(data)
+                )
+
+                ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
+                is_streamed = False
+                if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
+                    is_streamed = True
+
+                # iterate over the complete streamed response, and return the final answer
+                if is_streamed:
+                    streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
+                    content = ""
+                    for chunk in streamed_response:
+                        content += chunk["choices"][0]["delta"]["content"]
+                    completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
+                    ## LOGGING
+                    logging_obj.post_call(
+                        input=input_text,
+                        api_key=api_key,
+                        original_response=completion_response,
+                        additional_args={"complete_input_dict": data, "task": task},
+                    )
+                else:
+                    ## LOGGING
+                    logging_obj.post_call(
+                        input=input_text,
+                        api_key=api_key,
+                        original_response=response.text,
+                        additional_args={"complete_input_dict": data, "task": task},
+                    )
+                    ## RESPONSE OBJECT
+                    try:
+                        completion_response = response.json()
+                    except:
+                        import traceback
+                        raise HuggingfaceError(
+                            message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
+                        )
+                    print_verbose(f"response: {completion_response}")
+                    if isinstance(completion_response, dict) and "error" in completion_response:
+                        print_verbose(f"completion error: {completion_response['error']}")
+                        print_verbose(f"response.status_code: {response.status_code}")
+                        raise HuggingfaceError(
+                            message=completion_response["error"],
+                            status_code=response.status_code,
+                        )
+
+            return self.convert_to_model_response_object(
+                completion_response=completion_response,
+                model_response=model_response,
+                task=task,
+                optional_params=optional_params,
+                encoding=encoding,
+                input_text=input_text,
+                model=model
+            )
+        except HuggingfaceError as e:
+            exception_mapping_worked = True
+            raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
             else:
-                output_data.append(
-                    {
-                        "object": "embedding",
-                        "index": idx,
-                        "embedding": embedding[0][0] # flatten list returned from hf
-                    }
-                )
-    model_response["object"] = "list"
-    model_response["data"] = output_data
-    model_response["model"] = model
-    input_tokens = 0
-    for text in input:
-        input_tokens+=len(encoding.encode(text))
+                import traceback
+                raise HuggingfaceError(status_code=500, message=traceback.format_exc())
-    model_response["usage"] = {
-        "prompt_tokens": input_tokens,
-        "total_tokens": input_tokens,
-    }
-    return model_response
+    async def acompletion(self,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          task: str,
+                          encoding: Any,
+                          input_text: str,
+                          model: str,
+                          optional_params: dict):
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
+        try:
+            response = await client.post(url=api_base, json=data, headers=headers)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+            ## RESPONSE OBJECT
+            return self.convert_to_model_response_object(completion_response=response_json,
+                                                         model_response=model_response,
+                                                         task=task,
+                                                         encoding=encoding,
+                                                         input_text=input_text,
+                                                         model=model,
+                                                         optional_params=optional_params)
+        except Exception as e:
+            if isinstance(e,httpx.TimeoutException):
+                raise HuggingfaceError(status_code=500, message="Request Timeout Error")
+            elif response and hasattr(response, "text"):
+                raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                raise HuggingfaceError(status_code=500, message=f"{str(e)}")
+
+    async def async_streaming(self,
+                              logging_obj,
+                              api_base: str,
+                              data: dict,
+                              headers: dict,
+                              model_response: ModelResponse,
+                              model: str):
+        if self._aclient_session is None:
+            self._aclient_session = self.create_aclient_session()
+        client = self._aclient_session
+        async with client.stream(
+            url=f"{api_base}",
+            json=data,
+            headers=headers,
+            method="POST"
+        ) as response:
+            if response.status_code != 200:
+                raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
+
+            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+
+    def embedding(self,
+        model: str,
+        input: list,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        logging_obj=None,
+        model_response=None,
+        encoding=None,
+    ):
+        super().embedding()
+        headers = self.validate_environment(api_key, headers=None)
+        # print_verbose(f"{model}, {task}")
+        embed_url = ""
+        if "https" in model:
+            embed_url = model
+        elif api_base:
+            embed_url = api_base
+        elif "HF_API_BASE" in os.environ:
+            embed_url = os.getenv("HF_API_BASE", "")
+        elif "HUGGINGFACE_API_BASE" in os.environ:
+            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+        else:
+            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+        if "sentence-transformers" in model:
+            if len(input) == 0:
+                raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
+            data = {
+                "inputs": {
+                    "source_sentence": input[0],
+                    "sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
+                }
+            }
+        else:
+            data = {
+                "inputs": input # type: ignore
+            }
+
+        ## LOGGING
+        logging_obj.pre_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+            )
+        ## COMPLETION CALL
+        response = requests.post(
+            embed_url, headers=headers, data=json.dumps(data)
+        )
+
+
+        ## LOGGING
+        logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+
+
+        embeddings = response.json()
+
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings['error'])
+
+        output_data = []
+        if "similarities" in embeddings:
+            for idx, embedding in embeddings["similarities"]:
+                output_data.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding # flatten list returned from hf
+                    }
+                )
+        else:
+            for idx, embedding in enumerate(embeddings):
+                if isinstance(embedding, float):
+                    output_data.append(
+                        {
+                            "object": "embedding",
+                            "index": idx,
+                            "embedding": embedding # flatten list returned from hf
+                        }
+                    )
+                else:
+                    output_data.append(
+                        {
+                            "object": "embedding",
+                            "index": idx,
+                            "embedding": embedding[0][0] # flatten list returned from hf
+                        }
+                    )
+        model_response["object"] = "list"
+        model_response["data"] = output_data
+        model_response["model"] = model
+        input_tokens = 0
+        for text in input:
+            input_tokens+=len(encoding.encode(text))
+
+        model_response["usage"] = {
+            "prompt_tokens": input_tokens,
+            "total_tokens": input_tokens,
+        }
+        return model_response
diff --git a/litellm/main.py b/litellm/main.py
index 58a909fa7..fb3675677 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -53,6 +53,7 @@ from .llms import (
     maritalk)
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
+from .llms.huggingface_restapi import Huggingface
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -77,6 +78,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
+huggingface = Huggingface()
 ####### COMPLETION ENDPOINTS ################
 
 class LiteLLM:
@@ -165,7 +167,8 @@ async def acompletion(*args, **kwargs):
         if (custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
             or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            or custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
             if kwargs.get("stream", False):
                 response = completion(*args, **kwargs)
             else:
@@ -862,7 +865,7 @@ def completion(
                 custom_prompt_dict
                 or litellm.custom_prompt_dict
             )
-            model_response = huggingface_restapi.completion(
+            model_response = huggingface.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base, # type: ignore
@@ -874,10 +877,11 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=huggingface_key,
+                acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if "stream" in optional_params and optional_params["stream"] == True and acompletion is False:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="huggingface", logging_obj=logging
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 79fcdb4a7..6081956f2 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -25,11 +25,12 @@ def test_sync_response():
 
 def test_async_response():
     import asyncio
+    litellm.set_verbose = True
     async def test_get_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="command-nightly", messages=messages)
+            response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages)
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
@@ -44,7 +45,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
             print(type(response))
 
             import inspect
@@ -67,15 +68,16 @@ def test_get_response_streaming():
 
     asyncio.run(test_async_call())
 
-test_get_response_streaming()
+# test_get_response_streaming()
 
 def test_get_response_non_openai_streaming():
     import asyncio
+    litellm.set_verbose = True
     async def test_async_call():
         user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] try: - response = await acompletion(model="command-nightly", messages=messages, stream=True) + response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages, stream=True) print(type(response)) import inspect @@ -98,4 +100,4 @@ def test_get_response_non_openai_streaming(): return response asyncio.run(test_async_call()) -# test_get_response_non_openai_streaming() +test_get_response_non_openai_streaming() diff --git a/litellm/utils.py b/litellm/utils.py index 4373ebe1c..a4d1016a1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -511,6 +511,8 @@ class Logging: masked_headers = {k: v[:-40] + '*' * 40 if len(v) > 40 else v for k, v in headers.items()} formatted_headers = " ".join([f"-H '{k}: {v}'" for k, v in masked_headers.items()]) + print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") + curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command += "curl -X POST \\\n" curl_command += f"{api_base} \\\n" @@ -4313,7 +4315,6 @@ class CustomStreamWrapper: def handle_huggingface_chunk(self, chunk): try: - chunk = chunk.decode("utf-8") text = "" is_finished = False finish_reason = "" @@ -4770,7 +4771,8 @@ class CustomStreamWrapper: if (self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure" or self.custom_llm_provider == "custom_openai" - or self.custom_llm_provider == "text-completion-openai"): + or self.custom_llm_provider == "text-completion-openai" + or self.custom_llm_provider == "huggingface"): async for chunk in self.completion_stream: if chunk == "None" or chunk is None: raise Exception