Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Merge branch 'main' into main
Commit 4278b183d0
18 changed files with 1000 additions and 102 deletions. Showing the hunks for litellm/main.py.
@@ -1,6 +1,5 @@
-import os, openai, cohere, replicate, sys
+import os, openai, sys
 from typing import Any
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
 from functools import partial
 import dotenv, traceback, random, asyncio, time
 from copy import deepcopy
@@ -8,15 +7,9 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)  # for exponential backoff
-from litellm.utils import get_secret
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 
 new_response = {
     "choices": [
         {
@@ -28,9 +21,7 @@ new_response = {
         }
     ]
 }
-# TODO move this to utils.py
 # TODO add translations
-# TODO see if this worked - model_name == krrish
 ####### COMPLETION ENDPOINTS ################
 #############################################
 async def acompletion(*args, **kwargs):
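This hunk ends at the acompletion wrapper, whose body is outside the diff. Given the functools.partial import kept in the header, a plausible shape is the standard run-in-executor pattern; this is only a hedged sketch, assuming completion is the synchronous function defined further down:

    import asyncio
    from functools import partial

    async def acompletion(*args, **kwargs):
        # Sketch only: run the synchronous completion() on the default
        # executor so callers can await it; the committed body is not shown.
        loop = asyncio.get_event_loop()
        func = partial(completion, *args, **kwargs)  # bind args for the executor
        return await loop.run_in_executor(None, func)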
@@ -52,7 +43,8 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
+    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
+    hugging_face = False, replicate=False,
 ):
     try:
         global new_response
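The two new keyword-only flags let a caller force the Hugging Face or Replicate code path even when the model string alone would not select it. A minimal usage sketch; the model ids are illustrative, not taken from this commit:

    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # force the Hugging Face Inference API path
    response = completion("bigcode/starcoder", messages, hugging_face=True)

    # force the Replicate path
    response = completion("replicate/vicuna-13b", messages, replicate=True)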
@@ -61,13 +53,16 @@ def completion(
         optional_params = get_optional_params(
             functions=functions, function_call=function_call,
             temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
-            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id
+            presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
+            # params to identify the model
+            model=model, replicate=replicate, hugging_face=hugging_face
         )
         if azure == True:
             # azure configs
             openai.api_type = "azure"
             openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
             openai.api_version = litellm.api_version if litellm.api_version is not None else get_secret("AZURE_API_VERSION")
             # set key
             if api_key:
                 openai.api_key = api_key
             elif litellm.azure_key:
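get_optional_params now also receives model, replicate, and hugging_face, so it can shape the param dict per provider. Its body lives in litellm/utils.py and is not part of this diff; the following is only a hedged sketch of that idea, with invented mappings:

    def get_optional_params(model=None, replicate=False, hugging_face=False,
                            stream=False, max_tokens=float("inf"), **kwargs):
        # Sketch only: drop unset defaults and rename params per provider.
        optional_params = {}
        if stream:
            optional_params["stream"] = stream
        if max_tokens != float("inf"):
            if replicate:
                # illustrative: many replicate models take max_new_tokens
                optional_params["max_new_tokens"] = max_tokens
            else:
                optional_params["max_tokens"] = max_tokens
        return optional_params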
@@ -92,6 +87,7 @@ def completion(
             )
         elif model in litellm.open_ai_chat_completion_models:
+            openai.api_type = "openai"
             # note: if a user sets a custom base - we should ensure this works
             openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
             openai.api_version = None
             if litellm.organization:
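Resetting api_type and falling back to litellm.api_base keeps Azure and plain OpenAI calls from leaking state into each other, and lets the OpenAI branch target any OpenAI-compatible endpoint. A usage sketch; the URL is illustrative:

    import litellm
    from litellm import completion

    # point the OpenAI branch at an OpenAI-compatible proxy
    litellm.api_base = "http://localhost:8000/v1"  # illustrative

    messages = [{"role": "user", "content": "ping"}]
    response = completion("gpt-3.5-turbo", messages)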
@@ -154,7 +150,10 @@ def completion(
             model_response["model"] = model
             model_response["usage"] = response["usage"]
             response = model_response
-        elif "replicate" in model:
+        elif "replicate" in model or replicate == True:
+            # import replicate/if it fails then pip install replicate
+            install_and_import("replicate")
+            import replicate
             # replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
             # checking in case user set it to REPLICATE_API_KEY instead
             if not get_secret("REPLICATE_API_TOKEN") and get_secret("REPLICATE_API_KEY"):
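install_and_import is what replaces the module-level cohere/replicate/anthropic imports deleted in the first hunk: each SDK is installed and imported only when its branch actually runs. The helper itself is defined in litellm/utils.py, not in this diff; a hedged sketch of such a helper:

    import importlib
    import subprocess
    import sys

    def install_and_import(package: str) -> None:
        # Sketch only: import the package, pip-installing it first if missing.
        try:
            importlib.import_module(package)
        except ImportError:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        finally:
            globals()[package] = importlib.import_module(package)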
@@ -175,6 +174,11 @@ def completion(
             output = replicate.run(
                 model,
                 input=input)
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                # let the stream handler know this is replicate
+                response = CustomStreamWrapper(output, "replicate")
+                return response
             response = ""
             for item in output:
                 response += item
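replicate.run() already yields text pieces, so when stream=True the branch hands the raw iterator to CustomStreamWrapper tagged as "replicate" instead of consuming it. The wrapper is defined in litellm/utils.py and not shown here; a hedged sketch of an adapter in that spirit, with an assumed OpenAI-style chunk shape:

    class CustomStreamWrapper:
        # Sketch only: normalize provider streams into OpenAI-style chunks.
        def __init__(self, completion_stream, model):
            self.completion_stream = completion_stream
            self.model = model

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.completion_stream)  # StopIteration ends the stream
            if self.model == "replicate":
                text = chunk  # replicate yields plain strings
            else:
                text = getattr(chunk, "completion", str(chunk))
            return {"choices": [{"delta": {"content": text}}]}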
@@ -194,6 +198,10 @@ def completion(
             }
             response = model_response
         elif model in litellm.anthropic_models:
+            # import anthropic/if it fails then pip install anthropic
+            install_and_import("anthropic")
+            from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+
             #anthropic defaults to os.environ.get("ANTHROPIC_API_KEY")
             if api_key:
                 os.environ["ANTHROPIC_API_KEY"] = api_key
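The Anthropic SDK import moves inside its branch. Its completions API takes one prompt string built from HUMAN_PROMPT and AI_PROMPT tags rather than a messages list; the conversion itself sits outside this hunk, but it conventionally looks like this sketch:

    from anthropic import HUMAN_PROMPT, AI_PROMPT  # "\n\nHuman:" / "\n\nAssistant:"

    prompt = ""
    for message in messages:
        if message["role"] == "user":
            prompt += f"{HUMAN_PROMPT} {message['content']}"
        else:
            prompt += f"{AI_PROMPT} {message['content']}"
    prompt += AI_PROMPT  # the prompt must end with the assistant tag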
@@ -220,8 +228,14 @@ def completion(
             completion = anthropic.completions.create(
                 model=model,
                 prompt=prompt,
-                max_tokens_to_sample=max_tokens_to_sample
+                max_tokens_to_sample=max_tokens_to_sample,
+                **optional_params
             )
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(completion, model)
+                return response
+
             completion_response = completion.completion
             ## LOGGING
             logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
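With stream=True the Anthropic branch now returns the wrapper immediately rather than a finished response. Assuming chunks shaped like the wrapper sketch above, a caller would consume it like this (model id is illustrative):

    response = completion("claude-instant-1", messages, stream=True)
    for chunk in response:
        # chunk shape depends on CustomStreamWrapper; this follows the sketch above
        print(chunk["choices"][0]["delta"]["content"], end="")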
@@ -274,6 +288,9 @@ def completion(
                 **optional_params
             )
         elif model in litellm.cohere_models:
+            # import cohere/if it fails then pip install cohere
+            install_and_import("cohere")
+            import cohere
             if api_key:
                 cohere_key = api_key
             elif litellm.cohere_key:
@@ -287,8 +304,14 @@ def completion(
             ## COMPLETION CALL
             response = co.generate(
                 model=model,
-                prompt = prompt
+                prompt = prompt,
+                **optional_params
             )
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(response, model)
+                return response
+
             completion_response = response[0].text
             ## LOGGING
             logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
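The Cohere branch gets the same two changes as Anthropic: **optional_params forwarded into the SDK call and a streaming early-return. Stripped of litellm's plumbing, the call reduces to a direct SDK invocation like this sketch (key handling and model id are illustrative):

    import os
    import cohere

    co = cohere.Client(os.environ["COHERE_API_KEY"])
    response = co.generate(
        model="command-nightly",  # illustrative model id
        prompt="Hello world",
        max_tokens=50,
    )
    print(response[0].text)  # same access pattern the diff uses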
@@ -304,6 +327,33 @@ def completion(
                 "total_tokens": prompt_tokens + completion_tokens
             }
             response = model_response
+        elif hugging_face == True:
+            import requests
+            API_URL = f"https://api-inference.huggingface.co/models/{model}"
+            HF_TOKEN = get_secret("HF_TOKEN")
+            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+
+            prompt = " ".join([message["content"] for message in messages])
+            ## LOGGING
+            logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+            input_payload = {"inputs": prompt}
+            response = requests.post(API_URL, headers=headers, json=input_payload)
+
+            completion_response = response.json()[0]['generated_text']
+            ## LOGGING
+            logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+            prompt_tokens = len(encoding.encode(prompt))
+            completion_tokens = len(encoding.encode(completion_response))
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["message"]["content"] = completion_response
+            model_response["created"] = time.time()
+            model_response["model"] = model
+            model_response["usage"] = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens
+            }
+            response = model_response
         else:
             ## LOGGING
             logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
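The new hugging_face branch needs no SDK at all: it POSTs the flattened prompt to the serverless Inference API and rebuilds an OpenAI-style response, with token counts from the module-level tiktoken encoding. The same request as a standalone sketch; the model id and input are illustrative, while the URL shape and HF_TOKEN variable mirror the diff:

    import os
    import requests

    model = "bigscience/bloom"  # illustrative model id
    API_URL = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

    resp = requests.post(API_URL, headers=headers, json={"inputs": "Hello, world"})
    resp.raise_for_status()
    print(resp.json()[0]["generated_text"])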