adding support for meta-llama-2

Krrish Dholakia 2023-08-14 13:06:33 -07:00
parent b5875cc4bd
commit 6aff47083b
12 changed files with 220 additions and 43 deletions


@@ -7,6 +7,7 @@ import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params, get_litellm_params
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
from .llms.anthropic import AnthropicLLM
from .llms.huggingface_restapi import HuggingfaceRestAPILLM
import tiktoken
from concurrent.futures import ThreadPoolExecutor
encoding = tiktoken.get_encoding("cl100k_base")
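For reference, the cl100k_base encoding loaded here is what the usage counts further down are computed with. A minimal sketch of that token counting (sample strings are illustrative, not from this commit):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

# usage numbers are plain encode-and-count estimates, as in the Hugging Face branch below
prompt = "Hello, how are you?"
completion_text = "I'm doing well, thanks for asking."
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(completion_text))
print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)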
@@ -222,7 +223,6 @@ def completion(
                response = CustomStreamWrapper(model_response, model)
                return response
            response = model_response
        elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
            openai.api_type = "openai"
            # not sure if this will work after someone first uses another API
@@ -305,37 +305,15 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
elif custom_llm_provider == "huggingface":
import requests
API_URL = f"https://api-inference.huggingface.co/models/{model}"
HF_TOKEN = get_secret("HF_TOKEN")
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
input_payload = {"inputs": prompt}
response = requests.post(API_URL, headers=headers, json=input_payload)
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
if isinstance(response, dict) and "error" in response:
raise Exception(response["error"])
json_response = response.json()
if 'error' in json_response: # raise HF errors when they exist
raise Exception(json_response['error'])
completion_response = json_response[0]['generated_text']
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(completion_response))
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
model_response["model"] = model
model_response["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
        elif model in litellm.huggingface_models or custom_llm_provider == "huggingface":
            custom_llm_provider = "huggingface"
            huggingface_key = api_key if api_key is not None else litellm.huggingface_key
            huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key)
            model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
            if 'stream' in optional_params and optional_params['stream'] == True:
                # don't try to access stream object,
                response = CustomStreamWrapper(model_response, model, custom_llm_provider="huggingface")
                return response
            response = model_response
        elif custom_llm_provider == "together_ai":
            import requests
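For context, a hedged sketch of how the new Hugging Face branch would be exercised from the caller's side. The model id and the stream flag are illustrative, not taken from this commit, and it assumes CustomStreamWrapper is iterable:

import litellm

messages = [{"role": "user", "content": "Hello, how are you?"}]

# non-streaming: the branch above delegates to HuggingfaceRestAPILLM.completion
response = litellm.completion(
    model="meta-llama/Llama-2-7b-chat-hf",  # illustrative Hugging Face model id
    messages=messages,
    custom_llm_provider="huggingface",
)

# streaming: the branch wraps model_response in CustomStreamWrapper and returns it immediately
stream = litellm.completion(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=messages,
    custom_llm_provider="huggingface",
    stream=True,
)
for chunk in stream:  # assumption: the wrapper yields chunks when iterated
    print(chunk)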
@@ -383,7 +361,7 @@ def completion(
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
chat_model = ChatModel.from_pretrained(model)
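Since this hunk mainly widens what gets handed to the logging helper, here is a hedged sketch of a caller-supplied logger_fn. The exact shape of the argument it receives is an assumption, not defined in this diff:

# assumption: logger_fn is invoked with a single dict of model call details
# (model, input, litellm_params, optional_params, ...); key names are illustrative
def my_logger_fn(model_call_details):
    print("model:", model_call_details.get("model"))
    print("details:", model_call_details)

# "chat-bison" is shown only because this hunk sits in the Vertex AI chat branch
response = litellm.completion(model="chat-bison", messages=messages, logger_fn=my_logger_fn)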
@@ -434,13 +412,13 @@ def completion(
            ## LOGGING
            logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
            args = locals()
            raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
            raise ValueError(f"Unable to map your input to a model. Check your input - {args}")
        return response
    except Exception as e:
        ## LOGGING
        logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
        ## Map to OpenAI Exception
        raise exception_type(model=model, original_exception=e)
        raise exception_type(model=model, custom_llm_provider=custom_llm_provider, original_exception=e)

def batch_completion(*args, **kwargs):
    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
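A usage sketch for batch_completion, inferred only from the two lines shown above; the fan-out behaviour and the return shape are assumptions:

# assumption: batch_completion takes a model plus a list of message lists and runs one
# completion() call per conversation (the ThreadPoolExecutor import suggests the calls
# are dispatched concurrently), returning one response per conversation
batch_messages = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [{"role": "user", "content": "Write a haiku about the sea."}],
]
try:
    responses = litellm.batch_completion("meta-llama/Llama-2-7b-chat-hf", messages=batch_messages)
except Exception as e:
    # completion() re-raises provider errors mapped through exception_type (see hunk above)
    print("mapped exception:", e)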