diff --git a/docs/my-website/docs/completion/huggingface_tutorial.md b/docs/my-website/docs/completion/huggingface_tutorial.md
new file mode 100644
index 0000000000..ea822dec56
--- /dev/null
+++ b/docs/my-website/docs/completion/huggingface_tutorial.md
@@ -0,0 +1,45 @@
+# Llama2 - Huggingface Tutorial
+[Huggingface](https://huggingface.co/) is an open-source platform for deploying machine-learning models.
+
+## Call Llama2 with Huggingface Inference Endpoints
+LiteLLM makes it easy to call the default huggingface endpoints as well as your own public or private Inference Endpoints.
+
+In this tutorial, let's call 3 models:
+- `deepset/deberta-v3-large-squad2`: calls the default huggingface endpoint
+- `meta-llama/Llama-2-7b-hf`: calls a public endpoint
+- `meta-llama/Llama-2-7b-chat-hf`: calls your private endpoint
+
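+Before running the examples below, install LiteLLM and (optionally) set your huggingface token - a minimal setup sketch; the token is only needed for private endpoints and gated models like Llama2:
+
+```
+# pip install litellm
+import os
+
+os.environ["HF_TOKEN"] = "hf_..." # replace with your huggingface token
+```
+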
+### Case 1: Call default huggingface endpoint
+
+Here's the complete example:
+
+```
+from litellm import completion
+
+model = "deepset/deberta-v3-large-squad2"
+messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
+
+### CALLING ENDPOINT
+completion(model=model, messages=messages, custom_llm_provider="huggingface")
+```
+
+What's happening?
+- `model` - the name of the deployed model on huggingface
+- `messages` - the input. We accept the OpenAI chat format. For huggingface, by default we iterate through the list and concatenate each message["content"] into the prompt (see the sketch below).
+
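+Roughly, the default prompt assembly looks like this - a simplified sketch of what LiteLLM does internally, not code you need to write:
+
+```
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Hey, how's it going?"},
+]
+
+prompt = ""
+for message in messages:
+    prompt += message["content"]
+# prompt == "You are a helpful assistant.Hey, how's it going?"
+```
+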
+### Case 2: Call Llama2 public endpoint
+
+We've deployed `meta-llama/Llama-2-7b-hf` behind a public endpoint - `https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud`.
+
+Let's try it out:
+```
+from litellm import completion
+
+model = "meta-llama/Llama-2-7b-hf"
+messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
+custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"
+
+### CALLING ENDPOINT
+completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
+```
+
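+### Case 3: Call Llama2 private endpoint
+
+Calling your private endpoint works the same way - point `custom_api_base` at the endpoint URL and make sure your huggingface token has access to it. Here's a sketch; the URL below is a placeholder, so substitute the URL of your own private endpoint for `meta-llama/Llama-2-7b-chat-hf`:
+
+```
+import os
+from litellm import completion
+
+os.environ["HF_TOKEN"] = "hf_..." # token with access to the private endpoint
+
+model = "meta-llama/Llama-2-7b-chat-hf"
+messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
+custom_api_base = "https://<your-private-endpoint>.endpoints.huggingface.cloud"
+
+### CALLING ENDPOINT
+completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
+```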
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index d744f1490a..b85a8f7eff 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -22,7 +22,7 @@ const sidebars = {
{
type: 'category',
label: 'completion_function',
- items: ['completion/input', 'completion/supported','completion/output'],
+ items: ['completion/input', 'completion/supported','completion/output', 'completion/huggingface_tutorial'],
},
{
type: 'category',
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 017e7f3515..026afcf142 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -11,6 +11,7 @@ anthropic_key = None
replicate_key = None
cohere_key = None
openrouter_key = None
+huggingface_key = None
vertex_project = None
vertex_location = None
@@ -62,9 +63,6 @@ open_ai_chat_completion_models = [
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
- 'gpt-3.5-turbo',
- 'gpt-3.5-turbo-16k-0613',
- 'gpt-3.5-turbo-16k'
]
open_ai_text_completion_models = [
'text-davinci-003'
@@ -111,7 +109,22 @@ vertex_text_models = [
"text-bison@001"
]
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models
+huggingface_models = [
+ "meta-llama/Llama-2-7b-hf",
+ "meta-llama/Llama-2-7b-chat-hf",
+ "meta-llama/Llama-2-13b-hf",
+ "meta-llama/Llama-2-13b-chat-hf",
+ "meta-llama/Llama-2-70b-hf",
+ "meta-llama/Llama-2-70b-chat-hf",
+ "meta-llama/Llama-2-7b",
+ "meta-llama/Llama-2-7b-chat",
+ "meta-llama/Llama-2-13b",
+ "meta-llama/Llama-2-13b-chat",
+ "meta-llama/Llama-2-70b",
+ "meta-llama/Llama-2-70b-chat",
+] # these have been tested extensively. But by default, all text2text-generation and text-generation models are supported by liteLLM - https://docs.litellm.ai/docs/completion/supported
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + huggingface_models + vertex_chat_models + vertex_text_models
####### EMBEDDING MODELS ###################
open_ai_embedding_models = [
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 227d66ada6..3a9b2741d5 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index fd19706592..afc8469fd9 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 451db73573..92c3af615d 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 5b5d928b20..c617a0ae0f 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -13,6 +13,7 @@ class AnthropicError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
+ super().__init__(self.message) # Call the base class constructor with the parameters it needs
class AnthropicLLM:
@@ -75,7 +76,6 @@ class AnthropicLLM:
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
- print(f"completion_response: {completion_response}")
if "error" in completion_response:
raise AnthropicError(message=completion_response["error"], status_code=response.status_code)
else:
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
new file mode 100644
index 0000000000..7b88d03e65
--- /dev/null
+++ b/litellm/llms/huggingface_restapi.py
@@ -0,0 +1,94 @@
+## Uses the huggingface text generation inference API
+import os, json
+from enum import Enum
+import requests
+from litellm import logging
+import time
+from typing import Callable
+
+class HuggingfaceError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ super().__init__(self.message) # Call the base class constructor with the parameters it needs
+
+class HuggingfaceRestAPILLM():
+ def __init__(self, encoding, api_key=None) -> None:
+ self.encoding = encoding
+ self.validate_environment(api_key=api_key)
+
+ def validate_environment(self, api_key): # set up the environment required to run the model
+ self.headers = {
+ "content-type": "application/json",
+ }
+ # get the api key if it exists in the environment or is passed in, but don't require it
+ self.api_key = os.getenv("HF_TOKEN") if "HF_TOKEN" in os.environ else api_key
+        if self.api_key is not None:
+ self.headers["Authorization"] = f"Bearer {self.api_key}"
+
+ def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None): # logic for parsing in - calling - parsing out model completion calls
+ if custom_api_base:
+ completion_url = custom_api_base
+ elif "HF_API_BASE" in os.environ:
+ completion_url = os.getenv("HF_API_BASE")
+ else:
+ completion_url = f"https://api-inference.huggingface.co/models/{model}"
+ prompt = ""
+ if "meta-llama" in model and "chat" in model: # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
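+            # simplified approximation of the Llama-2 chat template (see the blog post linked above);
+            # assumes the usual system -> user -> assistant -> user ... ordering:
+            #   [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant} [INST] {user} [/INST] ...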
+            prompt = "[INST] "
+            for message in messages:
+                if message["role"] == "system":
+                    prompt += "<<SYS>>\n" + message["content"] + "\n<</SYS>>\n\n"
+                elif message["role"] == "user":
+                    prompt += message["content"] + " [/INST] "
+                elif message["role"] == "assistant":
+                    prompt += message["content"] + " [INST] "
+ else:
+ for message in messages:
+ prompt += f"{message['content']}"
+ ### MAP INPUT PARAMS
+ # max tokens
+ if "max_tokens" in optional_params:
+ value = optional_params.pop("max_tokens")
+ optional_params["max_new_tokens"] = value
+ data = {
+ "inputs": prompt,
+ # "parameters": optional_params
+ }
+ ## LOGGING
+ logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
+ ## COMPLETION CALL
+ response = requests.post(completion_url, headers=self.headers, data=json.dumps(data))
+ if "stream" in optional_params and optional_params["stream"] == True:
+ return response.iter_lines()
+ else:
+ ## LOGGING
+ logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params, "original_response": response.text}, logger_fn=logger_fn)
+ print_verbose(f"raw model_response: {response.text}")
+ ## RESPONSE OBJECT
+ completion_response = response.json()
+ print(f"response: {completion_response}")
+ if isinstance(completion_response, dict) and "error" in completion_response:
+ print(f"completion error: {completion_response['error']}")
+ print(f"response.status_code: {response.status_code}")
+ raise HuggingfaceError(message=completion_response["error"], status_code=response.status_code)
+ else:
+ model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+
+ ## CALCULATING USAGE
+ prompt_tokens = len(self.encoding.encode(prompt)) ##[TODO] use the llama2 tokenizer here
+ completion_tokens = len(self.encoding.encode(model_response["choices"][0]["message"]["content"])) ##[TODO] use the llama2 tokenizer here
+
+
+ model_response["created"] = time.time()
+ model_response["model"] = model
+ model_response["usage"] = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": prompt_tokens + completion_tokens
+ }
+ return model_response
+
+    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+ pass
\ No newline at end of file
diff --git a/litellm/main.py b/litellm/main.py
index 3abdccddf3..d17bebd0d1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -7,6 +7,7 @@ import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params, get_litellm_params
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
from .llms.anthropic import AnthropicLLM
+from .llms.huggingface_restapi import HuggingfaceRestAPILLM
import tiktoken
from concurrent.futures import ThreadPoolExecutor
encoding = tiktoken.get_encoding("cl100k_base")
@@ -222,7 +223,6 @@ def completion(
response = CustomStreamWrapper(model_response, model)
return response
response = model_response
-
elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
openai.api_type = "openai"
# not sure if this will work after someone first uses another API
@@ -305,37 +305,15 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
- elif custom_llm_provider == "huggingface":
- import requests
- API_URL = f"https://api-inference.huggingface.co/models/{model}"
- HF_TOKEN = get_secret("HF_TOKEN")
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
-
- prompt = " ".join([message["content"] for message in messages])
- ## LOGGING
- logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
- input_payload = {"inputs": prompt}
- response = requests.post(API_URL, headers=headers, json=input_payload)
- ## LOGGING
- logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
- if isinstance(response, dict) and "error" in response:
- raise Exception(response["error"])
- json_response = response.json()
- if 'error' in json_response: # raise HF errors when they exist
- raise Exception(json_response['error'])
-
- completion_response = json_response[0]['generated_text']
- prompt_tokens = len(encoding.encode(prompt))
- completion_tokens = len(encoding.encode(completion_response))
- ## RESPONSE OBJECT
- model_response["choices"][0]["message"]["content"] = completion_response
- model_response["created"] = time.time()
- model_response["model"] = model
- model_response["usage"] = {
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": prompt_tokens + completion_tokens
- }
+ elif model in litellm.huggingface_models or custom_llm_provider == "huggingface":
+ custom_llm_provider = "huggingface"
+ huggingface_key = api_key if api_key is not None else litellm.huggingface_key
+ huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key)
+ model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
+ if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access the stream object directly - wrap it so chunks follow the OpenAI format
+ response = CustomStreamWrapper(model_response, model, custom_llm_provider="huggingface")
+ return response
response = model_response
elif custom_llm_provider == "together_ai":
import requests
@@ -383,7 +361,7 @@ def completion(
prompt = " ".join([message["content"] for message in messages])
## LOGGING
- logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+ logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
chat_model = ChatModel.from_pretrained(model)
@@ -434,13 +412,13 @@ def completion(
## LOGGING
logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
args = locals()
- raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
+ raise ValueError(f"Unable to map your input to a model. Check your input - {args}")
return response
except Exception as e:
## LOGGING
logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
## Map to OpenAI Exception
- raise exception_type(model=model, original_exception=e)
+ raise exception_type(model=model, custom_llm_provider=custom_llm_provider, original_exception=e)
def batch_completion(*args, **kwargs):
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index c64c845362..f639e327db 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -49,6 +49,17 @@ def test_completion_hf_api():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
+def test_completion_hf_deployed_api():
+ try:
+ user_message = "There's a llama in my garden 😱 What should I do?"
+ messages = [{ "content": user_message,"role": "user"}]
+ response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+ # Add any assertions here to check the response
+ print(response)
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
+
+test_completion_hf_deployed_api()
def test_completion_cohere():
try:
response = completion(model="command-nightly", messages=messages, max_tokens=500)
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index b7332772fb..317dea904b 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -23,6 +23,17 @@ try:
for chunk in response:
print(chunk['choices'][0]['delta'])
score +=1
+except:
+ print(f"error occurred: {traceback.format_exc()}")
+ pass
+
+
+# test on huggingface completion call
+try:
+ response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", stream=True, logger_fn=logger_fn)
+ for chunk in response:
+ print(chunk['choices'][0]['delta'])
+ score +=1
except:
print(f"error occurred: {traceback.format_exc()}")
pass
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 6fc85aaa10..f57c390cbe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -589,7 +589,7 @@ def modify_integration(integration_name, integration_params):
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
-def exception_type(model, original_exception):
+def exception_type(model, original_exception, custom_llm_provider):
global user_logger_fn
exception_mapping_worked = False
try:
@@ -640,6 +640,17 @@ def exception_type(model, original_exception):
elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(f"CohereException - {original_exception.message}")
+ elif custom_llm_provider == "huggingface":
+ if hasattr(original_exception, "status_code"):
+ if original_exception.status_code == 401:
+ exception_mapping_worked = True
+ raise AuthenticationError(f"HuggingfaceException - {original_exception.message}")
+ elif original_exception.status_code == 400:
+ exception_mapping_worked = True
+ raise InvalidRequestError(f"HuggingfaceException - {original_exception.message}", f"{model}")
+ elif original_exception.status_code == 429:
+ exception_mapping_worked = True
+ raise RateLimitError(f"HuggingfaceException - {original_exception.message}")
raise original_exception # base case - return the original exception
else:
raise original_exception
@@ -715,8 +726,9 @@ def get_secret(secret_name):
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
class CustomStreamWrapper:
- def __init__(self, completion_stream, model):
+ def __init__(self, completion_stream, model, custom_llm_provider=None):
self.model = model
+ self.custom_llm_provider = custom_llm_provider
if model in litellm.cohere_models:
# cohere does not return an iterator, so we need to wrap it in one
self.completion_stream = iter(completion_stream)
@@ -745,6 +757,16 @@ class CustomStreamWrapper:
return extracted_text
else:
return ""
+
+ def handle_huggingface_chunk(self, chunk):
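+        # expected chunk format from huggingface text-generation-inference streams (an assumption, based on the TGI SSE output), e.g.:
+        #   b'data:{"token": {"id": 123, "text": " hello", "logprob": -0.1, "special": false}, "generated_text": null, "details": null}'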
+ chunk = chunk.decode("utf-8")
+ if chunk.startswith('data:'):
+ data_json = json.loads(chunk[5:])
+ if "token" in data_json and "text" in data_json["token"]:
+ return data_json["token"]["text"]
+ else:
+ return ""
+ return ""
def __next__(self):
completion_obj ={ "role": "assistant", "content": ""}
@@ -763,6 +785,9 @@ class CustomStreamWrapper:
elif self.model in litellm.cohere_models:
chunk = next(self.completion_stream)
completion_obj["content"] = chunk.text
+ elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
+ chunk = next(self.completion_stream)
+ completion_obj["content"] = self.handle_huggingface_chunk(chunk)
# return this for all models
return {"choices": [{"delta": completion_obj}]}