bump: version 0.8.4 → 0.8.5

Krrish Dholakia 2023-10-14 16:43:06 -07:00
parent 80c60e71c1
commit 7358d2e4ea
11 changed files with 228 additions and 7343 deletions


@@ -6,7 +6,7 @@ import requests
 import time
 import litellm
 from typing import Callable
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
 
@@ -65,12 +65,17 @@ class HuggingfaceConfig():
             and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
             and v is not None}
 
-def validate_environment(api_key):
-    headers = {
+def validate_environment(api_key, headers):
+    default_headers = {
         "content-type": "application/json",
     }
-    if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
+    if api_key and headers is None:
+        default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
+        headers = default_headers
+    elif headers:
+        headers = headers
+    else:
+        headers = default_headers
     return headers
 
 tgi_models_cache = None
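
Note on the hunk above: validate_environment now lets caller-supplied headers take precedence over the generated defaults, and the bearer token is only attached when no custom headers are given. A minimal standalone sketch of that precedence (the name resolve_headers is hypothetical, not part of the diff):

# resolve_headers is a hypothetical stand-in for the new validate_environment logic.
def resolve_headers(api_key, headers=None):
    default_headers = {"content-type": "application/json"}
    if api_key and headers is None:
        # No custom headers given: fall back to bearer-token auth.
        default_headers["Authorization"] = f"Bearer {api_key}"
        return default_headers
    # Caller-supplied headers win outright; otherwise use the bare defaults.
    return headers if headers else default_headers

assert resolve_headers("hf_abc")["Authorization"] == "Bearer hf_abc"
assert resolve_headers("hf_abc", {"Authorization": "Basic xyz"}) == {"Authorization": "Basic xyz"}
assert resolve_headers(None) == {"content-type": "application/json"}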
@@ -125,6 +130,7 @@ def completion(
     model: str,
     messages: list,
     api_base: Optional[str],
+    headers: Optional[dict],
     model_response: ModelResponse,
     print_verbose: Callable,
     encoding,
@@ -135,7 +141,8 @@
     litellm_params=None,
     logger_fn=None,
 ):
-    headers = validate_environment(api_key)
+    print(f'headers inside hf rest api: {headers}')
+    headers = validate_environment(api_key, headers)
     task = get_hf_task_for_model(model)
     print_verbose(f"{model}, {task}")
     completion_url = ""
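
Together with the signature change above, this threads an optional headers dict from completion() down into validate_environment. A hedged usage sketch (the model id and endpoint URL are placeholders, and forwarding of a top-level headers kwarg is assumed from this diff rather than shown in it):

import litellm

# Hypothetical call: a custom Authorization header should now reach the
# Huggingface handler instead of the Bearer default built from api_key.
response = litellm.completion(
    model="huggingface/bigcode/starcoder",           # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://my-endpoint.example.com",      # placeholder endpoint
    headers={"Authorization": "Basic xyz"},          # overrides the default
)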
@@ -227,7 +234,7 @@
     logging_obj.pre_call(
         input=input_text,
         api_key=api_key,
-        additional_args={"complete_input_dict": data, "task": task},
+        additional_args={"complete_input_dict": data, "task": task, "headers": headers},
     )
     ## COMPLETION CALL
     if "stream" in optional_params and optional_params["stream"] == True:
@@ -244,20 +251,43 @@
             headers=headers,
             data=json.dumps(data)
         )
-        ## LOGGING
-        logging_obj.post_call(
-            input=input_text,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data, "task": task},
-        )
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise HuggingfaceError(
-                message=response.text, status_code=response.status_code
-            )
+        ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
+        is_streamed = False
+        print(f"response keys: {response.__dict__.keys()}")
+        print(f"response keys: {response.__dict__['headers']}")
+        if response.__dict__['headers']["Content-Type"] == "text/event-stream":
+            is_streamed = True
+
+        # iterate over the complete streamed response, and return the final answer
+        if is_streamed:
+            streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
+            content = ""
+            for chunk in streamed_response:
+                content += chunk["choices"][0]["delta"]["content"]
+            completion_response = [{"generated_text": content}]
+            ## LOGGING
+            logging_obj.post_call(
+                input=input_text,
+                api_key=api_key,
+                original_response=completion_response,
+                additional_args={"complete_input_dict": data, "task": task},
+            )
+        else:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input_text,
+                api_key=api_key,
+                original_response=response.text,
+                additional_args={"complete_input_dict": data, "task": task},
+            )
+            ## RESPONSE OBJECT
+            try:
+                completion_response = response.json()
+            except:
+                raise HuggingfaceError(
+                    message=response.text, status_code=response.status_code
+                )
         print_verbose(f"response: {completion_response}")
         if isinstance(completion_response, dict) and "error" in completion_response:
             print_verbose(f"completion error: {completion_response['error']}")