mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

adding support for meta-llama-2

This commit is contained in: parent b5875cc4bd, commit 6aff47083b
12 changed files with 220 additions and 43 deletions

docs/my-website/docs/completion/huggingface_tutorial.md (new file, 45 lines)
@@ -0,0 +1,45 @@

# Llama2 - Huggingface Tutorial

[Huggingface](https://huggingface.co/) is an open-source platform for deploying machine-learning models.

## Call Llama2 with Huggingface Inference Endpoints

LiteLLM makes it easy to call your public, private, or the default Huggingface Inference endpoints.

In this case, let's try calling 3 models:

- `deepset/deberta-v3-large-squad2`: calls the default Huggingface Inference endpoint
- `meta-llama/Llama-2-7b-hf`: calls a public endpoint
- `meta-llama/Llama-2-7b-chat-hf`: calls your private endpoint

### Case 1: Call default huggingface endpoint

Here's the complete example:

```
from litellm import completion

model = "deepset/deberta-v3-large-squad2"
messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format

### CALLING ENDPOINT
completion(model=model, messages=messages, custom_llm_provider="huggingface")
```

What's happening?

- model - the name of the deployed model on Huggingface.
- messages - the input. We accept the OpenAI chat format. For Huggingface, by default we iterate through the list and add each message["content"] to the prompt, as sketched below.
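
To make that default behavior concrete, here's a minimal sketch of the concatenation (illustrative only; it mirrors the default branch of the Huggingface handler added in this commit):

```
messages = [
    {"role": "system", "content": "You are a helpful assistant. "},
    {"role": "user", "content": "Hey, how's it going?"},
]

# Default (non-chat) behavior: concatenate each message's content into one prompt string.
prompt = ""
for message in messages:
    prompt += f"{message['content']}"

print(prompt)  # "You are a helpful assistant. Hey, how's it going?"
```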

### Case 2: Call Llama2 public endpoint

We've deployed `meta-llama/Llama-2-7b-hf` behind a public endpoint - `https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud`.

Let's try it out:
```
from litellm import completion

model = "meta-llama/Llama-2-7b-hf"
messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
custom_api_base = "https://ag3dkq4zui5nu8g3.us-east-1.aws.endpoints.huggingface.cloud"

### CALLING ENDPOINT
completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
```
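
For the third model in the list above - `meta-llama/Llama-2-7b-chat-hf` on your own private endpoint - the call looks the same. A sketch, assuming a placeholder endpoint URL and a Huggingface access token exported as `HF_TOKEN` (LiteLLM picks it up and sends it as a Bearer token):

```
import os
from litellm import completion

os.environ["HF_TOKEN"] = "hf_..."  # your Huggingface access token (placeholder value)

model = "meta-llama/Llama-2-7b-chat-hf"
messages = [{"role": "user", "content": "Hey, how's it going?"}] # LiteLLM follows the OpenAI format
custom_api_base = "https://<your-private-endpoint>.endpoints.huggingface.cloud"  # placeholder

### CALLING ENDPOINT
completion(model=model, messages=messages, custom_llm_provider="huggingface", custom_api_base=custom_api_base)
```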

@@ -22,7 +22,7 @@ const sidebars = {
     {
       type: 'category',
       label: 'completion_function',
-      items: ['completion/input', 'completion/supported','completion/output'],
+      items: ['completion/input', 'completion/supported','completion/output', 'completion/huggingface_tutorial'],
     },
     {
       type: 'category',

@@ -11,6 +11,7 @@ anthropic_key = None
 replicate_key = None
 cohere_key = None
 openrouter_key = None
+huggingface_key = None
 vertex_project = None
 vertex_location = None

@@ -62,9 +63,6 @@ open_ai_chat_completion_models = [
     "gpt-3.5-turbo-16k",
     "gpt-3.5-turbo-0613",
     "gpt-3.5-turbo-16k-0613",
-    'gpt-3.5-turbo',
-    'gpt-3.5-turbo-16k-0613',
-    'gpt-3.5-turbo-16k'
 ]
 open_ai_text_completion_models = [
     'text-davinci-003'

@@ -111,7 +109,22 @@ vertex_text_models = [
     "text-bison@001"
 ]
 
-model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models
+huggingface_models = [
+    "meta-llama/Llama-2-7b-hf",
+    "meta-llama/Llama-2-7b-chat-hf",
+    "meta-llama/Llama-2-13b-hf",
+    "meta-llama/Llama-2-13b-chat-hf",
+    "meta-llama/Llama-2-70b-hf",
+    "meta-llama/Llama-2-70b-chat-hf",
+    "meta-llama/Llama-2-7b",
+    "meta-llama/Llama-2-7b-chat",
+    "meta-llama/Llama-2-13b",
+    "meta-llama/Llama-2-13b-chat",
+    "meta-llama/Llama-2-70b",
+    "meta-llama/Llama-2-70b-chat",
+] # these have been tested extensively, but by default all text2text-generation and text-generation models are supported by liteLLM - https://docs.litellm.ai/docs/completion/supported
+
+model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + huggingface_models + vertex_chat_models + vertex_text_models
 
 ####### EMBEDDING MODELS ###################
 open_ai_embedding_models = [
(3 binary files not shown)

@@ -13,6 +13,7 @@ class AnthropicError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        super().__init__(self.message)  # Call the base class constructor with the parameters it needs
 
 class AnthropicLLM:
 

@@ -75,7 +76,6 @@ class AnthropicLLM:
         print_verbose(f"raw model_response: {response.text}")
         ## RESPONSE OBJECT
         completion_response = response.json()
-        print(f"completion_response: {completion_response}")
         if "error" in completion_response:
             raise AnthropicError(message=completion_response["error"], status_code=response.status_code)
         else:

litellm/llms/huggingface_restapi.py (new file, 94 lines)
@@ -0,0 +1,94 @@

## Uses the huggingface text generation inference API
import os, json
from enum import Enum
import requests
from litellm import logging
import time
from typing import Callable

class HuggingfaceError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs

class HuggingfaceRestAPILLM():
    def __init__(self, encoding, api_key=None) -> None:
        self.encoding = encoding
        self.validate_environment(api_key=api_key)

    def validate_environment(self, api_key):  # set up the environment required to run the model
        self.headers = {
            "content-type": "application/json",
        }
        # get the api key if it exists in the environment or is passed in, but don't require it
        self.api_key = os.getenv("HF_TOKEN") if "HF_TOKEN" in os.environ else api_key
        if self.api_key is not None:
            self.headers["Authorization"] = f"Bearer {self.api_key}"

    def completion(self, model: str, messages: list, custom_api_base: str, model_response: dict, print_verbose: Callable, optional_params=None, litellm_params=None, logger_fn=None):  # logic for parsing in - calling - parsing out model completion calls
        if custom_api_base:
            completion_url = custom_api_base
        elif "HF_API_BASE" in os.environ:
            completion_url = os.getenv("HF_API_BASE")
        else:
            completion_url = f"https://api-inference.huggingface.co/models/{model}"
        prompt = ""
        if "meta-llama" in model and "chat" in model:  # use the required special tokens for meta-llama - https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            prompt = "<s>"
            for message in messages:
                if message["role"] == "system":
                    prompt += "[INST] <<SYS>>" + message["content"]
                elif message["role"] == "assistant":
                    prompt += message["content"] + "</s><s>[INST]"
                elif message["role"] == "user":
                    prompt += message["content"] + "[/INST]"
        else:
            for message in messages:
                prompt += f"{message['content']}"
        ### MAP INPUT PARAMS
        # max tokens
        if "max_tokens" in optional_params:
            value = optional_params.pop("max_tokens")
            optional_params["max_new_tokens"] = value
        data = {
            "inputs": prompt,
            # "parameters": optional_params
        }
        ## LOGGING
        logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
        ## COMPLETION CALL
        response = requests.post(completion_url, headers=self.headers, data=json.dumps(data))
        if "stream" in optional_params and optional_params["stream"] == True:
            return response.iter_lines()
        else:
            ## LOGGING
            logging(model=model, input=prompt, additional_args={"litellm_params": litellm_params, "optional_params": optional_params, "original_response": response.text}, logger_fn=logger_fn)
            print_verbose(f"raw model_response: {response.text}")
            ## RESPONSE OBJECT
            completion_response = response.json()
            print(f"response: {completion_response}")
            if isinstance(completion_response, dict) and "error" in completion_response:
                print(f"completion error: {completion_response['error']}")
                print(f"response.status_code: {response.status_code}")
                raise HuggingfaceError(message=completion_response["error"], status_code=response.status_code)
            else:
                model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]

            ## CALCULATING USAGE
            prompt_tokens = len(self.encoding.encode(prompt))  ##[TODO] use the llama2 tokenizer here
            completion_tokens = len(self.encoding.encode(model_response["choices"][0]["message"]["content"]))  ##[TODO] use the llama2 tokenizer here

            model_response["created"] = time.time()
            model_response["model"] = model
            model_response["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
            return model_response

    def embedding(self):  # logic for parsing in - calling - parsing out model embedding calls
        pass
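
To make the meta-llama chat branch above concrete, here's a small illustrative trace of the prompt string it builds for a hypothetical system + user exchange (the messages are made up; the loop mirrors the branch above):

```
messages = [
    {"role": "system", "content": "You are a helpful assistant. "},
    {"role": "user", "content": "Hey, how's it going?"},
]

# Mirrors the "meta-llama ... chat" branch of HuggingfaceRestAPILLM.completion above.
prompt = "<s>"
for message in messages:
    if message["role"] == "system":
        prompt += "[INST] <<SYS>>" + message["content"]
    elif message["role"] == "assistant":
        prompt += message["content"] + "</s><s>[INST]"
    elif message["role"] == "user":
        prompt += message["content"] + "[/INST]"

print(prompt)
# <s>[INST] <<SYS>>You are a helpful assistant. Hey, how's it going?[/INST]
```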

@@ -7,6 +7,7 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params, get_litellm_params
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
 from .llms.anthropic import AnthropicLLM
+from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")

@@ -222,7 +223,6 @@ def completion(
                 response = CustomStreamWrapper(model_response, model)
                 return response
             response = model_response
-
         elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
             openai.api_type = "openai"
             # not sure if this will work after someone first uses another API

@@ -305,37 +305,15 @@ def completion(
                 "total_tokens": prompt_tokens + completion_tokens
             }
             response = model_response
-        elif custom_llm_provider == "huggingface":
-            import requests
-            API_URL = f"https://api-inference.huggingface.co/models/{model}"
-            HF_TOKEN = get_secret("HF_TOKEN")
-            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
-
-            prompt = " ".join([message["content"] for message in messages])
-            ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
-            input_payload = {"inputs": prompt}
-            response = requests.post(API_URL, headers=headers, json=input_payload)
-            ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
-            if isinstance(response, dict) and "error" in response:
-                raise Exception(response["error"])
-            json_response = response.json()
-            if 'error' in json_response: # raise HF errors when they exist
-                raise Exception(json_response['error'])
-
-            completion_response = json_response[0]['generated_text']
-            prompt_tokens = len(encoding.encode(prompt))
-            completion_tokens = len(encoding.encode(completion_response))
-            ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens
-            }
+        elif model in litellm.huggingface_models or custom_llm_provider == "huggingface":
+            custom_llm_provider = "huggingface"
+            huggingface_key = api_key if api_key is not None else litellm.huggingface_key
+            huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key)
+            model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(model_response, model, custom_llm_provider="huggingface")
+                return response
             response = model_response
         elif custom_llm_provider == "together_ai":
             import requests
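
Since the routing above now checks `model in litellm.huggingface_models`, the listed Llama-2 models should reach the Huggingface handler even without `custom_llm_provider`. A usage sketch against this version of the API (assumes `HF_TOKEN` is set in the environment if the endpoint requires auth):

```
from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# "meta-llama/Llama-2-7b-hf" is in litellm.huggingface_models, so this call should be
# routed to HuggingfaceRestAPILLM without passing custom_llm_provider explicitly.
response = completion(model="meta-llama/Llama-2-7b-hf", messages=messages)
print(response["choices"][0]["message"]["content"])
```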

@@ -383,7 +361,7 @@ def completion(
 
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
 
             chat_model = ChatModel.from_pretrained(model)
 

@@ -434,13 +412,13 @@ def completion(
             ## LOGGING
             logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
             args = locals()
-            raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
+            raise ValueError(f"Unable to map your input to a model. Check your input - {args}")
         return response
     except Exception as e:
         ## LOGGING
         logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
         ## Map to OpenAI Exception
-        raise exception_type(model=model, original_exception=e)
+        raise exception_type(model=model, custom_llm_provider=custom_llm_provider, original_exception=e)
 
 def batch_completion(*args, **kwargs):
     batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")

@@ -49,6 +49,17 @@ def test_completion_hf_api():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_completion_hf_deployed_api():
+    try:
+        user_message = "There's a llama in my garden 😱 What should I do?"
+        messages = [{ "content": user_message,"role": "user"}]
+        response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", logger_fn=logger_fn)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_hf_deployed_api()
 def test_completion_cohere():
     try:
         response = completion(model="command-nightly", messages=messages, max_tokens=500)

@@ -26,3 +26,14 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
+
+
+# test on huggingface completion call (streaming)
+try:
+    response = completion(model="meta-llama/Llama-2-7b-chat-hf", messages=messages, custom_llm_provider="huggingface", custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud", stream=True, logger_fn=logger_fn)
+    for chunk in response:
+        print(chunk['choices'][0]['delta'])
+    score += 1
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass

@@ -589,7 +589,7 @@ def modify_integration(integration_name, integration_params):
         if "table_name" in integration_params:
             Supabase.supabase_table_name = integration_params["table_name"]
 
-def exception_type(model, original_exception):
+def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn
     exception_mapping_worked = False
     try:

@@ -640,6 +640,17 @@ def exception_type(model, original_exception):
         elif "CohereConnectionError" in exception_type: # cohere seems to fire these errors when we load test it (1k+ messages / min)
             exception_mapping_worked = True
             raise RateLimitError(f"CohereException - {original_exception.message}")
+        elif custom_llm_provider == "huggingface":
+            if hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 401:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(f"HuggingfaceException - {original_exception.message}")
+                elif original_exception.status_code == 400:
+                    exception_mapping_worked = True
+                    raise InvalidRequestError(f"HuggingfaceException - {original_exception.message}", f"{model}")
+                elif original_exception.status_code == 429:
+                    exception_mapping_worked = True
+                    raise RateLimitError(f"HuggingfaceException - {original_exception.message}")
         raise original_exception # base case - return the original exception
     else:
         raise original_exception
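
A brief note on how this mapping surfaces to callers: a 401/400/429 from a Huggingface endpoint is re-raised as an OpenAI-style exception, so existing error handling keeps working. A hedged sketch, assuming these exception classes come from `openai.error` (the "Map to OpenAI Exception" comment in `completion()` implies OpenAI-compatible errors, but the exact import path here is an assumption):

```
from openai.error import AuthenticationError  # assumed import path for the mapped error class
from litellm import completion

try:
    completion(model="meta-llama/Llama-2-7b-chat-hf",
               messages=[{"role": "user", "content": "Hi"}],
               custom_llm_provider="huggingface")
except AuthenticationError as e:
    # An HTTP 401 from Huggingface is mapped to AuthenticationError by exception_type().
    print(f"HuggingfaceException mapped to an OpenAI-style error: {e}")
```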

@@ -715,8 +726,9 @@ def get_secret(secret_name):
 # wraps the completion stream to return the correct format for the model
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream, model):
+    def __init__(self, completion_stream, model, custom_llm_provider=None):
         self.model = model
+        self.custom_llm_provider = custom_llm_provider
         if model in litellm.cohere_models:
             # cohere does not return an iterator, so we need to wrap it in one
             self.completion_stream = iter(completion_stream)

@@ -746,6 +758,16 @@ class CustomStreamWrapper:
         else:
             return ""
 
+    def handle_huggingface_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        if chunk.startswith('data:'):
+            data_json = json.loads(chunk[5:])
+            if "token" in data_json and "text" in data_json["token"]:
+                return data_json["token"]["text"]
+            else:
+                return ""
+        return ""
+
     def __next__(self):
         completion_obj = {"role": "assistant", "content": ""}
         if self.model in litellm.anthropic_models:

@@ -763,6 +785,9 @@ class CustomStreamWrapper:
         elif self.model in litellm.cohere_models:
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk.text
+        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
+            chunk = next(self.completion_stream)
+            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
         # return this for all models
         return {"choices": [{"delta": completion_obj}]}
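
For context on what the streaming path above consumes: when `stream=True`, the Huggingface handler returns `response.iter_lines()`, and each line is expected to be an SSE-style `data:` payload carrying a `token.text` field. A small sketch of the parsing (the token ids, logprobs, and texts below are made up for illustration):

```
import json

# Hypothetical raw lines as they might arrive from a Huggingface text-generation-inference stream.
raw_chunks = [
    b'data:{"token": {"id": 10994, "text": " Hello", "logprob": -0.5, "special": false}}',
    b'data:{"token": {"id": 29991, "text": "!", "logprob": -0.1, "special": false}}',
]

def parse_chunk(chunk):
    # Same parsing steps as CustomStreamWrapper.handle_huggingface_chunk above.
    chunk = chunk.decode("utf-8")
    if chunk.startswith('data:'):
        data_json = json.loads(chunk[5:])
        if "token" in data_json and "text" in data_json["token"]:
            return data_json["token"]["text"]
    return ""

print("".join(parse_chunk(c) for c in raw_chunks))  # " Hello!"
```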