Merge branch 'main' into main

commit 4278b183d0
Krish Dholakia, 2023-08-09 11:00:40 -07:00 (committed by GitHub)
18 changed files with 1000 additions and 102 deletions


@@ -1,6 +1,5 @@
-import os, openai, cohere, replicate, sys
+import os, openai, sys
from typing import Any
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from functools import partial
import dotenv, traceback, random, asyncio, time
from copy import deepcopy
@@ -8,15 +7,9 @@ import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
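# quick usage check for this encoding (editor's sketch; "hello world" is an arbitrary example):
#   len(encoding.encode("hello world"))  # -> 2 tokens under cl100k_base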
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-) # for exponential backoff
-from litellm.utils import get_secret
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
new_response = {
"choices": [
{
@@ -28,9 +21,7 @@ new_response = {
}
]
}
# TODO move this to utils.py
# TODO add translations
# TODO see if this worked - model_name == krrish
####### COMPLETION ENDPOINTS ################
#############################################
async def acompletion(*args, **kwargs):
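# (the function body is elided by the hunk boundary below; given the asyncio and
# functools.partial imports above, a plausible sketch is running the sync
# completion call in an executor:
#   loop = asyncio.get_event_loop()
#   func = partial(completion, *args, **kwargs)
#   return await loop.run_in_executor(None, func)
# )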
@@ -52,7 +43,8 @@ def completion(
temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
# Optional liteLLM function params
-*, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
+*, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
+hugging_face = False, replicate=False,
):
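# usage sketch for the new routing flags (editor's example; model ids are placeholders,
# not from this diff):
#   completion(model="gpt-3.5-turbo", messages=messages)                      # OpenAI chat route
#   completion(model="<hf-model-id>", messages=messages, hugging_face=True)   # Hugging Face route
#   completion(model="<replicate-model>", messages=messages, replicate=True)  # Replicate route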
try:
global new_response
@@ -61,13 +53,16 @@ def completion(
optional_params = get_optional_params(
functions=functions, function_call=function_call,
temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
+presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
+# params to identify the model
+model=model, replicate=replicate, hugging_face=hugging_face
)
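# the idea behind passing model/replicate/hugging_face here (assumed behavior, not
# the actual litellm.utils code) is to forward only the params the target provider
# supports, roughly:
#   def get_optional_params(*, stream=False, max_tokens=float("inf"), replicate=False, **kwargs):
#       optional_params = {}
#       if stream:
#           optional_params["stream"] = stream
#       if replicate and max_tokens != float("inf"):
#           optional_params["max_new_tokens"] = max_tokens  # assumed replicate param name
#       return optional_params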
if azure == True:
# azure configs
openai.api_type = "azure"
openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
openai.api_version = litellm.api_version if litellm.api_version is not None else get_secret("AZURE_API_VERSION")
+# set key
if api_key:
openai.api_key = api_key
elif litellm.azure_key:
@@ -92,6 +87,7 @@ def completion(
)
elif model in litellm.open_ai_chat_completion_models:
openai.api_type = "openai"
+# note: if a user sets a custom base - we should ensure this works
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
openai.api_version = None
if litellm.organization:
@@ -154,7 +150,10 @@ def completion(
model_response["model"] = model
model_response["usage"] = response["usage"]
response = model_response
elif "replicate" in model:
elif "replicate" in model or replicate == True:
# import replicate/if it fails then pip install replicate
install_and_import("replicate")
import replicate
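# a minimal sketch of the install_and_import pattern relied on here (assumed
# behavior, not the actual litellm.utils implementation):
#   def install_and_import(package):
#       import importlib, subprocess, sys
#       try:
#           importlib.import_module(package)
#       except ImportError:
#           subprocess.check_call([sys.executable, "-m", "pip", "install", package])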
# replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
# checking in case the user set it as REPLICATE_API_KEY instead
if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
@@ -175,6 +174,11 @@ def completion(
output = replicate.run(
model,
input=input)
+if 'stream' in optional_params and optional_params['stream'] == True:
+# don't try to access the stream object here;
+# let the stream handler know this is replicate
+response = CustomStreamWrapper(output, "replicate")
+return response
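# assumed shape of CustomStreamWrapper (editor's sketch, not litellm's actual
# class): wrap the provider's iterator and re-emit OpenAI-style delta chunks
#   class CustomStreamWrapper:
#       def __init__(self, completion_stream, model):
#           self.stream = iter(completion_stream)
#           self.model = model
#       def __iter__(self):
#           return self
#       def __next__(self):
#           text = next(self.stream)  # replicate yields plain text pieces
#           return {"choices": [{"delta": {"content": text}}]}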
response = ""
for item in output:
response += item
@@ -194,6 +198,10 @@ def completion(
}
response = model_response
elif model in litellm.anthropic_models:
+# import anthropic; if the import fails, pip install it first
+install_and_import("anthropic")
+from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
# anthropic defaults to os.environ.get("ANTHROPIC_API_KEY")
if api_key:
os.environ["ANTHROPIC_API_KEY"] = api_key
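# the prompt passed to the completion call below is built from `messages`
# (construction elided by the diff); with the SDK's HUMAN_PROMPT/AI_PROMPT
# constants it would look roughly like:
#   prompt = ""
#   for message in messages:
#       role_prefix = HUMAN_PROMPT if message["role"] == "user" else AI_PROMPT
#       prompt += f"{role_prefix} {message['content']}"
#   prompt += AI_PROMPT  # Claude completes after the final assistant turn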
@@ -220,8 +228,14 @@ def completion(
completion = anthropic.completions.create(
model=model,
prompt=prompt,
-max_tokens_to_sample=max_tokens_to_sample
+max_tokens_to_sample=max_tokens_to_sample,
+**optional_params
)
+if 'stream' in optional_params and optional_params['stream'] == True:
+# don't try to access the stream object; hand it off to the stream wrapper
+response = CustomStreamWrapper(completion, model)
+return response
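# consuming the wrapped stream (editor's example; assumes the chunk shape
# sketched for CustomStreamWrapper above, and the model id is illustrative):
#   for chunk in completion(model="claude-instant-1", messages=messages, stream=True):
#       print(chunk["choices"][0]["delta"]["content"], end="")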
completion_response = completion.completion
## LOGGING
logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
@@ -274,6 +288,9 @@ def completion(
**optional_params
)
elif model in litellm.cohere_models:
+# import cohere; if the import fails, pip install it first
+install_and_import("cohere")
+import cohere
if api_key:
cohere_key = api_key
elif litellm.cohere_key:
@@ -287,8 +304,14 @@ def completion(
## COMPLETION CALL
response = co.generate(
model=model,
-prompt = prompt
+prompt = prompt,
+**optional_params
)
+if 'stream' in optional_params and optional_params['stream'] == True:
+# don't try to access the stream object; hand it off to the stream wrapper
+response = CustomStreamWrapper(response, model)
+return response
completion_response = response[0].text
## LOGGING
logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
@@ -304,6 +327,33 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
+elif hugging_face == True:
+import requests
+API_URL = f"https://api-inference.huggingface.co/models/{model}"
+HF_TOKEN = get_secret("HF_TOKEN")
+headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+prompt = " ".join([message["content"] for message in messages])
+## LOGGING
+logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+input_payload = {"inputs": prompt}
+response = requests.post(API_URL, headers=headers, json=input_payload)
+completion_response = response.json()[0]['generated_text']
+## LOGGING
+logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+prompt_tokens = len(encoding.encode(prompt))
+completion_tokens = len(encoding.encode(completion_response))
+## RESPONSE OBJECT
+model_response["choices"][0]["message"]["content"] = completion_response
+model_response["created"] = time.time()
+model_response["model"] = model
+model_response["usage"] = {
+"prompt_tokens": prompt_tokens,
+"completion_tokens": completion_tokens,
+"total_tokens": prompt_tokens + completion_tokens
+}
+response = model_response
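# usage sketch for the new Hugging Face route (editor's example; the model id is
# a placeholder and HF_TOKEN must be set in the environment):
#   response = completion(model="<hf-model-id>", messages=messages, hugging_face=True)
#   print(response["choices"][0]["message"]["content"])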
else:
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)