mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
adding support for meta-llama-2
This commit is contained in: parent b5875cc4bd, commit 6aff47083b
12 changed files with 220 additions and 43 deletions
@@ -7,6 +7,7 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params, get_litellm_params
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
 from .llms.anthropic import AnthropicLLM
+from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
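
The import added here wires the new HuggingfaceRestAPILLM handler into main.py; the cl100k_base encoding set up alongside it is handed to that handler later in the diff (encoding=encoding), presumably for the same token accounting the old inline code did. A minimal sketch of that counting, with illustrative strings only:

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

prompt = "What is the capital of France?"             # illustrative prompt
completion_text = "The capital of France is Paris."   # illustrative model output

# Approximate usage by encoding the raw strings, as the removed
# Hugging Face branch further down did before this commit.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(completion_text))
usage = {
    "prompt_tokens": prompt_tokens,
    "completion_tokens": completion_tokens,
    "total_tokens": prompt_tokens + completion_tokens,
}
print(usage)
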
@@ -222,7 +223,6 @@ def completion(
                 response = CustomStreamWrapper(model_response, model)
                 return response
             response = model_response
-
         elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
             openai.api_type = "openai"
             # not sure if this will work after someone first uses another API
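
When a caller passes stream=True, the dispatch above returns the CustomStreamWrapper rather than a finished response object. A hedged consumption sketch; the model id, the API-key handling, and the exact chunk shape are assumptions, not taken from this diff:

import os

from litellm import completion

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder credential for the chosen provider

response = completion(
    model="gpt-3.5-turbo",  # illustrative model whose branch wraps output in CustomStreamWrapper
    messages=[{"role": "user", "content": "Write a haiku about the sea."}],
    stream=True,
)

# CustomStreamWrapper is an iterator; each item is one incremental chunk.
for chunk in response:
    print(chunk)
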
@@ -305,37 +305,15 @@ def completion(
                 "total_tokens": prompt_tokens + completion_tokens
             }
             response = model_response
-        elif custom_llm_provider == "huggingface":
-            import requests
-            API_URL = f"https://api-inference.huggingface.co/models/{model}"
-            HF_TOKEN = get_secret("HF_TOKEN")
-            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
-
-            prompt = " ".join([message["content"] for message in messages])
-            ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
-            input_payload = {"inputs": prompt}
-            response = requests.post(API_URL, headers=headers, json=input_payload)
-            ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
-            if isinstance(response, dict) and "error" in response:
-                raise Exception(response["error"])
-            json_response = response.json()
-            if 'error' in json_response: # raise HF errors when they exist
-                raise Exception(json_response['error'])
-
-            completion_response = json_response[0]['generated_text']
-            prompt_tokens = len(encoding.encode(prompt))
-            completion_tokens = len(encoding.encode(completion_response))
-            ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens
-            }
+        elif model in litellm.huggingface_models or custom_llm_provider == "huggingface":
+            custom_llm_provider = "huggingface"
+            huggingface_key = api_key if api_key is not None else litellm.huggingface_key
+            huggingface_client = HuggingfaceRestAPILLM(encoding=encoding, api_key=huggingface_key)
+            model_response = huggingface_client.completion(model=model, messages=messages, custom_api_base=custom_api_base, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn)
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(model_response, model, custom_llm_provider="huggingface")
+                return response
+            response = model_response
         elif custom_llm_provider == "together_ai":
             import requests
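
The net effect of this hunk: Hugging Face hosted models, including the Llama 2 checkpoints this commit targets, now go through HuggingfaceRestAPILLM instead of ad-hoc requests calls, and they pick up streaming via CustomStreamWrapper. A usage sketch under the assumption that custom_llm_provider is accepted as a keyword argument of completion (the dispatch above suggests it is); the model id is illustrative:

import litellm
from litellm import completion

# The new branch falls back to litellm.huggingface_key when no api_key is passed in.
litellm.huggingface_key = "hf_..."  # placeholder token

response = completion(
    model="meta-llama/Llama-2-7b-chat-hf",  # illustrative Hugging Face model id
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    custom_llm_provider="huggingface",  # forces the Hugging Face branch even for ids not in litellm.huggingface_models
)

# The response object is indexed the same way the handler populates it above.
print(response["choices"][0]["message"]["content"])
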
@@ -383,7 +361,7 @@ def completion(
 
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"litellm_params": litellm_params, "optional_params": optional_params}, logger_fn=logger_fn)
 
             chat_model = ChatModel.from_pretrained(model)
 
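
This hunk only enriches the logging call on the Vertex AI chat path with litellm_params and optional_params. A hedged sketch of a custom logger_fn that would see those values; the dictionary keys are an assumption about what litellm's logging helper assembles, not something shown in this diff:

from litellm import completion

def my_logger(model_call_details):
    # Assumed shape: a dict of call details built by litellm's logging();
    # the key names used here are illustrative.
    print("model:", model_call_details.get("model"))
    print("additional args:", model_call_details.get("additional_args"))

response = completion(
    model="chat-bison",  # illustrative Vertex AI chat model
    messages=[{"role": "user", "content": "Summarize Llama 2 in one line."}],
    logger_fn=my_logger,
)
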
@@ -434,13 +412,13 @@ def completion(
             ## LOGGING
             logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
             args = locals()
-            raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
+            raise ValueError(f"Unable to map your input to a model. Check your input - {args}")
         return response
     except Exception as e:
         ## LOGGING
         logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn, exception=e)
         ## Map to OpenAI Exception
-        raise exception_type(model=model, original_exception=e)
+        raise exception_type(model=model, custom_llm_provider=custom_llm_provider, original_exception=e)
 
 def batch_completion(*args, **kwargs):
     batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
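
The final hunk rewords the fallback error and passes custom_llm_provider into exception_type, so provider failures can be mapped onto OpenAI-style exceptions with more context. A small sketch of what a caller might see for an unmapped model; the model name is deliberately bogus, and the surfaced exception type depends on exception_type's mapping:

from litellm import completion

try:
    completion(
        model="definitely-not-a-real-model",  # not in any provider list
        messages=[{"role": "user", "content": "hi"}],
    )
except Exception as e:
    # The ValueError raised in the else branch above ("Unable to map your input
    # to a model. Check your input - ...") is routed through exception_type()
    # before it reaches the caller.
    print(type(e).__name__, e)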