refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions
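Black rewrites Python source into a single canonical style (double quotes, trailing commas, wrapped call arguments), which is what produces the quote and line-wrapping churn in the hunks below. As a minimal sketch of that effect, here is black's Python API applied to one "before" line taken from the benchmark hunk; this snippet is illustrative only, is not part of the commit, and assumes black is installed:

import black

# One single-quoted line from the "before" side of the benchmark script.
source = "models = ['gpt-3.5-turbo', 'claude-2']\n"

# format_str applies essentially the same rewrite the pre-commit black hook runs on each file.
formatted = black.format_str(source, mode=black.Mode())
print(formatted)  # models = ["gpt-3.5-turbo", "claude-2"]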

View file

@@ -1,9 +1,13 @@
repos:
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
- repo: local

View file

@@ -9,33 +9,37 @@ import os
# Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ['gpt-3.5-turbo', 'claude-2']
models = ["gpt-3.5-turbo", "claude-2"]
# Enter LLM API keys
# https://docs.litellm.ai/docs/providers
os.environ['OPENAI_API_KEY'] = ""
os.environ['ANTHROPIC_API_KEY'] = ""
os.environ["OPENAI_API_KEY"] = ""
os.environ["ANTHROPIC_API_KEY"] = ""
# List of questions to benchmark (replace with your questions)
questions = [
"When will BerriAI IPO?",
"When will LiteLLM hit $100M ARR?"
]
questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"]
# Enter your system prompt here
# Enter your system prompt here
system_prompt = """
You are LiteLLMs helpful assistant
"""
@click.command()
@click.option('--system-prompt', default="You are a helpful assistant that can answer questions.", help="System prompt for the conversation.")
@click.option(
"--system-prompt",
default="You are a helpful assistant that can answer questions.",
help="System prompt for the conversation.",
)
def main(system_prompt):
for question in questions:
data = [] # Data for the current question
with tqdm(total=len(models)) as pbar:
for model in models:
colored_description = colored(f"Running question: {question} for model: {model}", 'green')
colored_description = colored(
f"Running question: {question} for model: {model}", "green"
)
pbar.set_description(colored_description)
start_time = time.time()
@@ -44,35 +48,43 @@ def main(system_prompt):
max_tokens=500,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
{"role": "user", "content": question},
],
)
end = time.time()
total_time = end - start_time
cost = completion_cost(completion_response=response)
raw_response = response['choices'][0]['message']['content']
raw_response = response["choices"][0]["message"]["content"]
data.append({
'Model': colored(model, 'light_blue'),
'Response': raw_response, # Colorize the response
'ResponseTime': colored(f"{total_time:.2f} seconds", "red"),
'Cost': colored(f"${cost:.6f}", 'green'), # Colorize the cost
})
data.append(
{
"Model": colored(model, "light_blue"),
"Response": raw_response, # Colorize the response
"ResponseTime": colored(f"{total_time:.2f} seconds", "red"),
"Cost": colored(f"${cost:.6f}", "green"), # Colorize the cost
}
)
pbar.update(1)
# Separate headers from the data
headers = ['Model', 'Response', 'Response Time (seconds)', 'Cost ($)']
headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"]
colwidths = [15, 80, 15, 10]
# Create a nicely formatted table for the current question
table = tabulate([list(d.values()) for d in data], headers, tablefmt="grid", maxcolwidths=colwidths)
table = tabulate(
[list(d.values()) for d in data],
headers,
tablefmt="grid",
maxcolwidths=colwidths,
)
# Print the table for the current question
colored_question = colored(question, 'green')
colored_question = colored(question, "green")
click.echo(f"\nBenchmark Results for '{colored_question}':")
click.echo(table) # Display the formatted table
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@@ -1,25 +1,22 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import litellm
from litellm import embedding, completion, completion_cost
from autoevals.llm import *
###################
import litellm
# litellm completion call
question = "which country has the highest population"
response = litellm.completion(
model = "gpt-3.5-turbo",
messages = [
{
"role": "user",
"content": question
}
],
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": question}],
)
print(response)
# use the auto eval Factuality() evaluator
@@ -27,9 +24,11 @@ print(response)
print("calling evaluator")
evaluator = Factuality()
result = evaluator(
output=response.choices[0]["message"]["content"], # response from litellm.completion()
expected="India", # expected output
input=question # question passed to litellm.completion
output=response.choices[0]["message"][
"content"
], # response from litellm.completion()
expected="India", # expected output
input=question, # question passed to litellm.completion
)
print(result)
print(result)

View file

@@ -4,9 +4,10 @@ from flask_cors import CORS
import traceback
import litellm
from util import handle_error
from litellm import completion
import os, dotenv, time
from litellm import completion
import os, dotenv, time
import json
dotenv.load_dotenv()
# TODO: set your keys in .env or here:
@@ -19,57 +20,72 @@ verbose = True
# litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/
######### PROMPT LOGGING ##########
os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/
os.environ[
"PROMPTLAYER_API_KEY"
] = "" # set your promptlayer key here - https://promptlayer.com/
# set callbacks
litellm.success_callback = ["promptlayer"]
############ HELPER FUNCTIONS ###################################
def print_verbose(print_statement):
if verbose:
print(print_statement)
app = Flask(__name__)
CORS(app)
@app.route('/')
@app.route("/")
def index():
return 'received!', 200
return "received!", 200
def data_generator(response):
for chunk in response:
yield f"data: {json.dumps(chunk)}\n\n"
@app.route('/chat/completions', methods=["POST"])
@app.route("/chat/completions", methods=["POST"])
def api_completion():
data = request.json
start_time = time.time()
if data.get('stream') == "True":
data['stream'] = True # convert to boolean
start_time = time.time()
if data.get("stream") == "True":
data["stream"] = True # convert to boolean
try:
if "prompt" not in data:
raise ValueError("data needs to have prompt")
data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
data[
"model"
] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
# COMPLETION CALL
system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}]
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": data.pop("prompt")},
]
data["messages"] = messages
print(f"data: {data}")
response = completion(**data)
## LOG SUCCESS
end_time = time.time()
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return Response(data_generator(response), mimetype='text/event-stream')
end_time = time.time()
if (
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
return Response(data_generator(response), mimetype="text/event-stream")
except Exception as e:
# call handle_error function
print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
## LOG FAILURE
end_time = time.time()
end_time = time.time()
traceback_exception = traceback.format_exc()
return handle_error(data=data)
return response
@app.route('/get_models', methods=["POST"])
@app.route("/get_models", methods=["POST"])
def get_models():
try:
return litellm.model_list
@@ -78,7 +94,8 @@ def get_models():
response = {"error": str(e)}
return response, 200
if __name__ == "__main__":
from waitress import serve
serve(app, host="0.0.0.0", port=4000, threads=500)
if __name__ == "__main__":
from waitress import serve
serve(app, host="0.0.0.0", port=4000, threads=500)

View file

@@ -3,27 +3,28 @@ from urllib.parse import urlparse, parse_qs
def get_next_url(response):
"""
"""
Function to get 'next' url from Link header
:param response: response from requests
:return: next url or None
"""
if 'link' not in response.headers:
return None
headers = response.headers
if "link" not in response.headers:
return None
headers = response.headers
next_url = headers['Link']
print(next_url)
start_index = next_url.find("<")
end_index = next_url.find(">")
next_url = headers["Link"]
print(next_url)
start_index = next_url.find("<")
end_index = next_url.find(">")
return next_url[1:end_index]
return next_url[1:end_index]
def get_models(url):
"""
Function to retrieve all models from paginated endpoint
:param url: base url to make GET request
:return: list of all models
Function to retrieve all models from paginated endpoint
:param url: base url to make GET request
:return: list of all models
"""
models = []
while url:
@@ -36,19 +37,21 @@ def get_models(url):
models.extend(payload)
return models
def get_cleaned_models(models):
"""
Function to clean retrieved models
:param models: list of retrieved models
:return: list of cleaned models
Function to clean retrieved models
:param models: list of retrieved models
:return: list of cleaned models
"""
cleaned_models = []
for model in models:
cleaned_models.append(model["id"])
return cleaned_models
# Get text-generation models
url = 'https://huggingface.co/api/models?filter=text-generation-inference'
url = "https://huggingface.co/api/models?filter=text-generation-inference"
text_generation_models = get_models(url)
cleaned_text_generation_models = get_cleaned_models(text_generation_models)
@@ -56,7 +59,7 @@ print(cleaned_text_generation_models)
# Get conversational models
url = 'https://huggingface.co/api/models?filter=conversational'
url = "https://huggingface.co/api/models?filter=conversational"
conversational_models = get_models(url)
cleaned_conversational_models = get_cleaned_models(conversational_models)
@@ -65,19 +68,23 @@ print(cleaned_conversational_models)
def write_to_txt(cleaned_models, filename):
"""
Function to write the contents of a list to a text file
:param cleaned_models: list of cleaned models
:param filename: name of the text file
Function to write the contents of a list to a text file
:param cleaned_models: list of cleaned models
:param filename: name of the text file
"""
with open(filename, 'w') as f:
with open(filename, "w") as f:
for item in cleaned_models:
f.write("%s\n" % item)
# Write contents of cleaned_text_generation_models to text_generation_models.txt
write_to_txt(cleaned_text_generation_models, 'huggingface_llms_metadata/hf_text_generation_models.txt')
write_to_txt(
cleaned_text_generation_models,
"huggingface_llms_metadata/hf_text_generation_models.txt",
)
# Write contents of cleaned_conversational_models to conversational_models.txt
write_to_txt(cleaned_conversational_models, 'huggingface_llms_metadata/hf_conversational_models.txt')
write_to_txt(
cleaned_conversational_models,
"huggingface_llms_metadata/hf_conversational_models.txt",
)

View file

@@ -1,4 +1,3 @@
import openai
api_base = f"http://0.0.0.0:8000"
@@ -8,29 +7,29 @@ openai.api_key = "temp-key"
print(openai.api_base)
print(f'LiteLLM: response from proxy with streaming')
print(f"LiteLLM: response from proxy with streaming")
response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
model="ollama/llama2",
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
],
stream=True
stream=True,
)
for chunk in response:
print(f'LiteLLM: streaming response from proxy {chunk}')
print(f"LiteLLM: streaming response from proxy {chunk}")
response = openai.ChatCompletion.create(
model="ollama/llama2",
messages = [
model="ollama/llama2",
messages=[
{
"role": "user",
"content": "this is a test request, acknowledge that you got it"
"content": "this is a test request, acknowledge that you got it",
}
]
],
)
print(f'LiteLLM: response from proxy {response}')
print(f"LiteLLM: response from proxy {response}")

View file

@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
import concurrent.futures
import random
@@ -74,10 +82,18 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
messages=[
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
}
],
)
print(response)
end_time = time.time()
@@ -92,11 +108,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 100
@@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
print("\nError Log:\n", error_log_file.read())

View file

@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
# os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
import concurrent.futures
import random
@@ -76,9 +84,12 @@ def make_openai_completion(question):
import requests
data = {
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': f'You are a helpful assistant. Answer this question{question}'},
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
},
],
}
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
@@ -89,8 +100,8 @@ def make_openai_completion(question):
log_file.write(
f"Question: {question[:100]}\nResponse ID: {response.get('id', 'N/A')} Url: {response.get('url', 'N/A')}\nTime: {end_time - start_time:.2f} seconds\n\n"
)
# polling the url
# polling the url
while True:
try:
url = response["url"]
@@ -107,7 +118,9 @@ def make_openai_completion(question):
)
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
)
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)
@@ -117,11 +130,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 10
@@ -142,7 +154,7 @@ successful_calls = 0
failed_calls = 0
for future in futures:
if future.done():
if future.done():
if future.result() is not None:
successful_calls += 1
else:
@@ -152,4 +164,3 @@ print(f"Load test Summary:")
print(f"Total Requests: {concurrent_calls}")
print(f"Successful Calls: {successful_calls}")
print(f"Failed Calls: {failed_calls}")

View file

@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
}
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
}
}]
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", # actual model name
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-functioncalling",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
router = Router(model_list=model_list)
file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
file_paths = [
"test_questions/question1.txt",
"test_questions/question2.txt",
"test_questions/question3.txt",
]
questions = []
for file_path in file_paths:
try:
print(file_path)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
content = file.read()
questions.append(content)
except FileNotFoundError as e:
@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
import concurrent.futures
import random
@@ -75,7 +83,12 @@ def make_openai_completion(question):
start_time = time.time()
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
messages=[
{
"role": "system",
"content": f"You are a helpful assistant. Answer this question{question}",
}
],
)
print(response)
end_time = time.time()
@@ -90,11 +103,10 @@ def make_openai_completion(question):
except Exception as e:
# Log exceptions for failed calls
with open("error_log.txt", "a") as error_log_file:
error_log_file.write(
f"Question: {question[:100]}\nException: {str(e)}\n\n"
)
error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
return None
# Number of concurrent calls (you can adjust this)
concurrent_calls = 150
@@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
print("\nError Log:\n", error_log_file.read())

View file

@@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
Callable
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Callable
] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
email: Optional[
@@ -42,20 +48,88 @@ aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
use_client: bool = False
logging: bool = True
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
_current_cost = 0 # private variable, used if max budget is set
max_budget: float = 0.0 # set the max budget across all providers
_openai_completion_params = [
"functions",
"function_call",
"temperature",
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
]
_litellm_completion_params = [
"metadata",
"acompletion",
"caching",
"mock_response",
"api_key",
"api_version",
"api_base",
"force_timeout",
"logger_fn",
"verbose",
"custom_llm_provider",
"litellm_logging_obj",
"litellm_call_id",
"use_client",
"id",
"fallbacks",
"azure",
"headers",
"model_list",
"num_retries",
"context_window_fallback_dict",
"roles",
"final_prompt_value",
"bos_token",
"eos_token",
"request_timeout",
"complete_response",
"self",
"client",
"rpm",
"tpm",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
"model_info",
"proxy_server_request",
"preset_cache_key",
]
_current_cost = 0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
@@ -66,23 +140,35 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
allowed_fails: int = 0
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
#############################################
def get_model_cost_map(url: str):
try:
with requests.get(url, timeout=5) as response: # set a 5 second timeout for the get request
response.raise_for_status() # Raise an exception if the request is unsuccessful
with requests.get(
url, timeout=5
) as response: # set a 5 second timeout for the get request
response.raise_for_status() # Raise an exception if the request is unsuccessful
content = response.json()
return content
except Exception as e:
import importlib.resources
import json
with importlib.resources.open_text("litellm", "model_prices_and_context_window_backup.json") as f:
with importlib.resources.open_text(
"litellm", "model_prices_and_context_window_backup.json"
) as f:
content = json.load(f)
return content
model_cost = get_model_cost_map(url=model_cost_map_url)
custom_prompt_dict:Dict[str, dict] = {}
custom_prompt_dict: Dict[str, dict] = {}
####### THREAD-SPECIFIC DATA ###################
class MyLocal(threading.local):
def __init__(self):
@@ -123,56 +209,51 @@ bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
for key, value in model_cost.items():
if value.get('litellm_provider') == 'openai':
if value.get("litellm_provider") == "openai":
open_ai_chat_completion_models.append(key)
elif value.get('litellm_provider') == 'text-completion-openai':
elif value.get("litellm_provider") == "text-completion-openai":
open_ai_text_completion_models.append(key)
elif value.get('litellm_provider') == 'cohere':
elif value.get("litellm_provider") == "cohere":
cohere_models.append(key)
elif value.get('litellm_provider') == 'anthropic':
elif value.get("litellm_provider") == "anthropic":
anthropic_models.append(key)
elif value.get('litellm_provider') == 'openrouter':
elif value.get("litellm_provider") == "openrouter":
openrouter_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-text-models':
elif value.get("litellm_provider") == "vertex_ai-text-models":
vertex_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
elif value.get("litellm_provider") == "vertex_ai-code-text-models":
vertex_code_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-language-models':
elif value.get("litellm_provider") == "vertex_ai-language-models":
vertex_language_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
elif value.get("litellm_provider") == "vertex_ai-vision-models":
vertex_vision_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
elif value.get("litellm_provider") == "vertex_ai-chat-models":
vertex_chat_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
vertex_code_chat_models.append(key)
elif value.get('litellm_provider') == 'ai21':
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get('litellm_provider') == 'nlp_cloud':
elif value.get("litellm_provider") == "nlp_cloud":
nlp_cloud_models.append(key)
elif value.get('litellm_provider') == 'aleph_alpha':
elif value.get("litellm_provider") == "aleph_alpha":
aleph_alpha_models.append(key)
elif value.get('litellm_provider') == 'bedrock':
elif value.get("litellm_provider") == "bedrock":
bedrock_models.append(key)
elif value.get('litellm_provider') == 'deepinfra':
elif value.get("litellm_provider") == "deepinfra":
deepinfra_models.append(key)
elif value.get('litellm_provider') == 'perplexity':
elif value.get("litellm_provider") == "perplexity":
perplexity_models.append(key)
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
"api.perplexity.ai",
"api.perplexity.ai",
"api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai",
"api.mistral.ai/v1"
"api.mistral.ai/v1",
]
# this is maintained for Exception Mapping
openai_compatible_providers: List = [
"anyscale",
"mistral",
"deepinfra",
"perplexity"
]
openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
# well supported replicate llms
@@ -209,23 +290,18 @@ huggingface_models: List = [
together_ai_models: List = [
# llama llms - chat
"togethercomputer/llama-2-70b-chat",
# llama llms - language / instruct
# llama llms - language / instruct
"togethercomputer/llama-2-70b",
"togethercomputer/LLaMA-2-7B-32K",
"togethercomputer/Llama-2-7B-32K-Instruct",
"togethercomputer/llama-2-7b",
# falcon llms
"togethercomputer/falcon-40b-instruct",
"togethercomputer/falcon-7b-instruct",
# alpaca
"togethercomputer/alpaca-7b",
# chat llms
"HuggingFaceH4/starchat-alpha",
# code llms
"togethercomputer/CodeLlama-34b",
"togethercomputer/CodeLlama-34b-Instruct",
@@ -234,29 +310,27 @@ together_ai_models: List = [
"NumbersStation/nsql-llama-2-7B",
"WizardLM/WizardCoder-15B-V1.0",
"WizardLM/WizardCoder-Python-34B-V1.0",
# language llms
"NousResearch/Nous-Hermes-Llama2-13b",
"Austism/chronos-hermes-13b",
"upstage/SOLAR-0-70b-16bit",
"WizardLM/WizardLM-70B-V1.0",
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML
baseten_models: List = [
"qvv0xeq",
"q841o8w",
"31dxrj3",
] # FALCON 7B # WizardLM # Mosaic ML
petals_models = [
"petals-team/StableBeluga2",
]
ollama_models = [
"llama2"
]
ollama_models = ["llama2"]
maritalk_models = [
"maritalk"
]
maritalk_models = ["maritalk"]
model_list = (
open_ai_chat_completion_models
@@ -308,7 +382,7 @@ provider_list: List = [
"anyscale",
"mistral",
"maritalk",
"custom", # custom apis
"custom", # custom apis
]
models_by_provider: dict = {
@@ -327,28 +401,28 @@ models_by_provider: dict = {
"ollama": ollama_models,
"deepinfra": deepinfra_models,
"perplexity": perplexity_models,
"maritalk": maritalk_models
"maritalk": maritalk_models,
}
# mapping for those models which have larger equivalents
# mapping for those models which have larger equivalents
longer_context_model_fallback_dict: dict = {
# openai chat completion models
"gpt-3.5-turbo": "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
"gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
"gpt-4": "gpt-4-32k",
"gpt-4-0314": "gpt-4-32k-0314",
"gpt-4-0613": "gpt-4-32k-0613",
# anthropic
"claude-instant-1": "claude-2",
"gpt-3.5-turbo": "gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
"gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
"gpt-4": "gpt-4-32k",
"gpt-4-0314": "gpt-4-32k-0314",
"gpt-4-0613": "gpt-4-32k-0613",
# anthropic
"claude-instant-1": "claude-2",
"claude-instant-1.2": "claude-2",
# vertexai
"chat-bison": "chat-bison-32k",
"chat-bison@001": "chat-bison-32k",
"codechat-bison": "codechat-bison-32k",
"codechat-bison": "codechat-bison-32k",
"codechat-bison@001": "codechat-bison-32k",
# openrouter
"openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
# openrouter
"openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
"openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
}
@@ -357,20 +431,23 @@ open_ai_embedding_models: List = ["text-embedding-ada-002"]
cohere_embedding_models: List = [
"embed-english-v3.0",
"embed-english-light-v3.0",
"embed-multilingual-v3.0",
"embed-english-v2.0",
"embed-english-light-v2.0",
"embed-multilingual-v2.0",
"embed-multilingual-v3.0",
"embed-english-v2.0",
"embed-english-light-v2.0",
"embed-multilingual-v2.0",
]
bedrock_embedding_models: List = [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
all_embedding_models = (
open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
)
####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = [
"dall-e-2",
"dall-e-3"
]
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout
@@ -394,11 +471,11 @@ from .utils import (
get_llm_provider,
completion_with_config,
register_model,
encode,
decode,
encode,
decode,
_calculate_retry_after,
_should_retry,
get_secret
get_secret,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
@@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,
AmazonAnthropicConfig,
AmazonCohereConfig,
AmazonLlamaConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig
from .main import * # type: ignore
@@ -429,13 +512,13 @@ from .exceptions import (
ServiceUnavailableError,
OpenAIError,
ContextWindowExceededError,
BudgetExceededError,
BudgetExceededError,
APIError,
Timeout,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError
APIResponseValidationError,
UnprocessableEntityError,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
from .router import Router
from .router import Router

View file

@@ -1,8 +1,9 @@
set_verbose = False
def print_verbose(print_statement):
try:
if set_verbose:
print(print_statement) # noqa
print(print_statement) # noqa
except:
pass
pass

View file

@@ -13,6 +13,7 @@ import inspect
import redis, litellm
from typing import List, Optional
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@@ -23,32 +24,26 @@ def _get_redis_kwargs():
"retry",
}
include_args = [
"url"
]
include_args = ["url"]
available_args = [
x for x in arg_spec.args if x not in exclude_args
] + include_args
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
return {
f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()
}
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
def _redis_kwargs_from_environment():
mapping = _get_redis_env_kwarg_mapping()
return_dict = {}
return_dict = {}
for k, v in mapping.items():
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
if value is not None:
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
if value is not None:
return_dict[v] = value
return return_dict
@@ -56,21 +51,26 @@ def _redis_kwargs_from_environment():
def get_redis_url_from_environment():
if "REDIS_URL" in os.environ:
return os.environ["REDIS_URL"]
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")
raise ValueError(
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
)
if "REDIS_PASSWORD" in os.environ:
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
else:
redis_password = ""
return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
return (
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
)
def get_redis_client(**env_overrides):
### check if "os.environ/<key-name>" passed in
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
v = v.replace("os.environ/", "")
value = litellm.get_secret(v)
env_overrides[k] = value
@@ -80,14 +80,14 @@ def get_redis_client(**env_overrides):
**env_overrides,
}
if "url" in redis_kwargs and redis_kwargs['url'] is not None:
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
return redis.Redis.from_url(**redis_kwargs)
elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis.Redis(**redis_kwargs)
return redis.Redis(**redis_kwargs)

View file

@@ -1,119 +1,166 @@
import os, json, time
import litellm
import litellm
from litellm.utils import ModelResponse
import requests, threading
from typing import Optional, Union, Literal
class BudgetManager:
def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None):
def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
):
self.client_type = client_type
self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai"
## load the data or init the initial dictionaries
self.load_data()
self.load_data()
def print_verbose(self, print_statement):
try:
if litellm.set_verbose:
import logging
logging.info(print_statement)
except:
pass
def load_data(self):
if self.client_type == "local":
# Check if user dict file exists
if os.path.isfile("user_cost.json"):
# Load the user dict
with open("user_cost.json", 'r') as json_file:
with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file)
else:
self.print_verbose("User Dictionary not found!")
self.user_dict = {}
self.user_dict = {}
self.print_verbose(f"user dict from local: {self.user_dict}")
elif self.client_type == "hosted":
# Load the user_dict from hosted db
url = self.api_base + "/get_budget"
headers = {'Content-Type': 'application/json'}
data = {
'project_name' : self.project_name
}
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name}
response = requests.post(url, headers=headers, json=data)
response = response.json()
if response["status"] == "error":
self.user_dict = {} # assume this means the user dict hasn't been stored yet
self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else:
self.user_dict = response["data"]
def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float = time.time()):
def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
created_at: float = time.time(),
):
self.user_dict[user] = {"total_budget": total_budget}
if duration is None:
return self.user_dict[user]
if duration == 'daily':
if duration == "daily":
duration_in_days = 1
elif duration == 'weekly':
elif duration == "weekly":
duration_in_days = 7
elif duration == 'monthly':
elif duration == "monthly":
duration_in_days = 28
elif duration == 'yearly':
elif duration == "yearly":
duration_in_days = 365
else:
raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""")
self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
raise ValueError(
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
)
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user]
def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0)
prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost
return projected_cost
def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"]
def update_cost(self, user: str, completion_obj: Optional[ModelResponse] = None, model: Optional[str] = None, input_text: Optional[str] = None, output_text: Optional[str] = None):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": input_text}])
completion_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": output_text}])
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj['model'] # if this throws an error try, model = completion_obj['model']
else:
raise ValueError("Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager")
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0)
def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": input_text}]
)
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj[
"model"
] # if this throws an error try, model = completion_obj['model']
else:
raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user]["model_cost"].get(model, 0)
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else:
self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]}
def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0)
def get_model_cost(self, user):
return self.user_dict[user].get("model_cost", 0)
def is_valid_user(self, user: str) -> bool:
return user in self.user_dict
def get_users(self):
return list(self.user_dict.keys())
def reset_cost(self, user):
self.user_dict[user]["current_cost"] = 0
self.user_dict[user]["model_cost"] = {}
return {"user": self.user_dict[user]}
def reset_on_duration(self, user: str):
# Get current and creation time
last_updated_at = self.user_dict[user]["last_updated_at"]
@@ -121,38 +168,39 @@ class BudgetManager:
# Convert duration from days to seconds
duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60
# Check if duration has elapsed
if current_time - last_updated_at >= duration_in_seconds:
# Reset cost if duration has elapsed and update the creation time
self.reset_cost(user)
self.user_dict[user]["last_updated_at"] = current_time
self._save_data_thread() # Save the data
def update_budget_all_users(self):
for user in self.get_users():
if "duration" in self.user_dict[user]:
self.reset_on_duration(user)
def _save_data_thread(self):
thread = threading.Thread(target=self.save_data) # [Non-Blocking]: saves data without blocking execution
thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start()
def save_data(self):
if self.client_type == "local":
import json
# save the user dict
with open("user_cost.json", 'w') as json_file:
json.dump(self.user_dict, json_file, indent=4) # Indent for pretty formatting
import json
# save the user dict
with open("user_cost.json", "w") as json_file:
json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"}
elif self.client_type == "hosted":
url = self.api_base + "/set_budget"
headers = {'Content-Type': 'application/json'}
data = {
'project_name' : self.project_name,
"user_dict": self.user_dict
}
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = requests.post(url, headers=headers, json=data)
response = response.json()
return response
return response

View file

@@ -12,13 +12,15 @@ import time, logging
import json, traceback, ast
from typing import Optional, Literal, List
def print_verbose(print_statement):
try:
if litellm.set_verbose:
print(print_statement) # noqa
print(print_statement) # noqa
except:
pass
class BaseCache:
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
@@ -45,13 +47,13 @@ class InMemoryCache(BaseCache):
self.cache_dict.pop(key, None)
return None
original_cached_response = self.cache_dict[key]
try:
try:
cached_response = json.loads(original_cached_response)
except:
except:
cached_response = original_cached_response
return cached_response
return None
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
@@ -60,17 +62,18 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache):
def __init__(self, host=None, port=None, password=None, **kwargs):
import redis
# if users don't provider one, use the default litellm cache
from ._redis import get_redis_client
redis_kwargs = {}
if host is not None:
if host is not None:
redis_kwargs["host"] = host
if port is not None:
redis_kwargs["port"] = port
if password is not None:
if password is not None:
redis_kwargs["password"] = password
redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
@@ -88,13 +91,19 @@ class RedisCache(BaseCache):
try:
print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key)
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
if cached_response != None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string
try:
cached_response = json.loads(cached_response) # Convert string to dictionary
except:
cached_response = cached_response.decode(
"utf-8"
) # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)
return cached_response
except Exception as e:
@@ -105,34 +114,40 @@ class RedisCache(BaseCache):
def flush_cache(self):
self.redis_client.flushall()
class DualCache(BaseCache):
class DualCache(BaseCache):
"""
This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None:
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
def set_cache(self, key, value, **kwargs):
# Update both Redis and in-memory cache
try:
try:
print_verbose(f"set cache: key: {key}; value: {value}")
if self.in_memory_cache is not None:
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
except Exception as e:
print_verbose(e)
def get_cache(self, key, **kwargs):
# Try to fetch from in-memory cache first
try:
try:
print_verbose(f"get cache: cache key: {key}")
result = None
if self.in_memory_cache is not None:
@@ -141,7 +156,7 @@ class DualCache(BaseCache):
if in_memory_result is not None:
result = in_memory_result
if self.redis_cache is not None:
if self.redis_cache is not None:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs)
@@ -153,25 +168,28 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
except Exception as e:
traceback.print_exc()
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
if self.redis_cache is not None:
self.redis_cache.flush_cache()
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
def __init__(
self,
type: Optional[Literal["local", "redis"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs
self,
type: Optional[Literal["local", "redis"]] = "local",
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[
List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs,
):
"""
Initializes the cache based on the given type.
@@ -200,7 +218,7 @@ class Cache:
litellm.success_callback.append("cache")
if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
def get_cache_key(self, *args, **kwargs):
"""
@ -215,18 +233,37 @@ class Cache:
"""
cache_key = ""
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
print_verbose(f"\nReturning preset cache key: {cache_key}")
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
completion_kwargs = [
"model",
"messages",
"temperature",
"top_p",
"n",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
]
embedding_only_kwargs = [
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs
combined_kwargs = completion_kwargs + embedding_only_kwargs
for param in combined_kwargs:
# ignore litellm params here
if param in kwargs:
@ -241,8 +278,8 @@ class Cache:
model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
for group in caching_groups:
if model_group in group:
caching_group = group
break
if litellm_params is not None:
@ -251,23 +288,34 @@ class Cache:
model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None)
if caching_groups:
for group in caching_groups:
if model_group in group:
for group in caching_groups:
if model_group in group:
caching_group = group
break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else:
if kwargs[param] is None:
continue # ignore None params
continue # ignore None params
param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}"
cache_key += f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key
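As a rough sketch of the key construction above - kwargs are concatenated in the fixed completion_kwargs order as "param: value" pairs - a local cache with no preset_cache_key and no caching groups would behave roughly like this; the expected string is inferred from the loop above, not taken from test output:

from litellm.caching import Cache

cache = Cache(type="local")
key = cache.get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
)
# expected shape, following the fixed ordering (model, messages, temperature, ...):
# "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'hi'}]temperature: 0.2"
print(key)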
def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size):
yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]}
yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(0.02)
def get_cache(self, *args, **kwargs):
@ -319,4 +367,4 @@ class Cache:
pass
async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs)
self.add_cache(result, *args, **kwargs)
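To wire this cache into completion calls globally, the pattern mirrored by the commented-out server_utils snippet later in this diff is to assign a Cache instance to litellm.cache; a sketch with placeholder Redis credentials:

import os
import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    host=os.getenv("REDIS_HOST"),
    port=os.getenv("REDIS_PORT"),
    password=os.getenv("REDIS_PASSWORD"),
    # default covers completion/acompletion/embedding/aembedding; narrowed here as an example
    supported_call_types=["completion", "acompletion"],
)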

View file

@ -1,2 +1,2 @@
# from .main import *
# from .server_utils import *
# from .server_utils import *

View file

@ -33,7 +33,7 @@
# llm_model_list: Optional[list] = None
# server_settings: Optional[dict] = None
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment
# if "CONFIG_FILE_PATH" in os.environ:
# llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
@ -44,7 +44,7 @@
# @router.get("/models") # if project requires model list
# def model_list():
# all_models = litellm.utils.get_valid_models()
# if llm_model_list:
# if llm_model_list:
# all_models += llm_model_list
# return dict(
# data=[
@ -79,8 +79,8 @@
# @router.post("/v1/embeddings")
# @router.post("/embeddings")
# async def embedding(request: Request):
# try:
# data = await request.json()
# try:
# data = await request.json()
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
# if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header
# api_key = request.headers.get("authorization")
@ -106,13 +106,13 @@
# data = await request.json()
# server_model = server_settings.get("completion_model", None) if server_settings else None
# data["model"] = server_model or model or data["model"]
# ## CHECK KEYS ##
# ## CHECK KEYS ##
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
# # env_validation = litellm.validate_environment(model=data["model"])
# # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
# # if "authorization" in request.headers:
# # api_key = request.headers.get("authorization")
# # elif "api-key" in request.headers:
# # elif "api-key" in request.headers:
# # api_key = request.headers.get("api-key")
# # print(f"api_key in headers: {api_key}")
# # if " " in api_key:
@ -122,11 +122,11 @@
# # api_key = api_key
# # data["api_key"] = api_key
# # print(f"api_key in data: {api_key}")
# ## CHECK CONFIG ##
# ## CHECK CONFIG ##
# if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
# for m in llm_model_list:
# if data["model"] == m["model_name"]:
# for key, value in m["litellm_params"].items():
# for m in llm_model_list:
# if data["model"] == m["model_name"]:
# for key, value in m["litellm_params"].items():
# data[key] = value
# break
# response = litellm.completion(
@ -145,21 +145,21 @@
# @router.post("/router/completions")
# async def router_completion(request: Request):
# global llm_router
# try:
# try:
# data = await request.json()
# if "model_list" in data:
# if "model_list" in data:
# llm_router = litellm.Router(model_list=data.pop("model_list"))
# if llm_router is None:
# if llm_router is None:
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
# # openai.ChatCompletion.create replacement
# response = await llm_router.acompletion(model="gpt-3.5-turbo",
# response = await llm_router.acompletion(model="gpt-3.5-turbo",
# messages=[{"role": "user", "content": "Hey, how's it going?"}])
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
# return StreamingResponse(data_generator(response), media_type='text/event-stream')
# return response
# except Exception as e:
# except Exception as e:
# error_traceback = traceback.format_exc()
# error_msg = f"{str(e)}\n\n{error_traceback}"
# return {"error": error_msg}
@ -167,11 +167,11 @@
# @router.post("/router/embedding")
# async def router_embedding(request: Request):
# global llm_router
# try:
# try:
# data = await request.json()
# if "model_list" in data:
# if "model_list" in data:
# llm_router = litellm.Router(model_list=data.pop("model_list"))
# if llm_router is None:
# if llm_router is None:
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
# response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore
@ -180,7 +180,7 @@
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
# return StreamingResponse(data_generator(response), media_type='text/event-stream')
# return response
# except Exception as e:
# except Exception as e:
# error_traceback = traceback.format_exc()
# error_msg = f"{str(e)}\n\n{error_traceback}"
# return {"error": error_msg}
@ -190,4 +190,4 @@
# return "LiteLLM: RUNNING"
# app.include_router(router)
# app.include_router(router)

View file

@ -3,7 +3,7 @@
# import dotenv
# dotenv.load_dotenv() # load env variables
# def print_verbose(print_statement):
# def print_verbose(print_statement):
# pass
# def get_package_version(package_name):
@ -27,32 +27,31 @@
# def set_callbacks():
# ## LOGGING
# if len(os.getenv("SET_VERBOSE", "")) > 0:
# if os.getenv("SET_VERBOSE") == "True":
# if len(os.getenv("SET_VERBOSE", "")) > 0:
# if os.getenv("SET_VERBOSE") == "True":
# litellm.set_verbose = True
# print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m")
# else:
# else:
# litellm.set_verbose = False
# ### LANGFUSE
# if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
# litellm.success_callback = ["langfuse"]
# litellm.success_callback = ["langfuse"]
# print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m")
# ## CACHING
# ## CACHING
# ### REDIS
# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
# # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
# # from litellm.caching import Cache
# # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
# # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
# def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
# config = {}
# server_settings = {}
# try:
# server_settings = {}
# try:
# if os.path.exists(config_file_path): # type: ignore
# with open(config_file_path, 'r') as file: # type: ignore
# config = yaml.safe_load(file)
@ -63,24 +62,24 @@
# ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
# server_settings = config.get("server_settings", None)
# if server_settings:
# if server_settings:
# server_settings = server_settings
# ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
# litellm_settings = config.get('litellm_settings', None)
# if litellm_settings:
# for key, value in litellm_settings.items():
# if litellm_settings:
# for key, value in litellm_settings.items():
# setattr(litellm, key, value)
# ## MODEL LIST
# model_list = config.get('model_list', None)
# if model_list:
# if model_list:
# router = litellm.Router(model_list=model_list)
# ## ENVIRONMENT VARIABLES
# environment_variables = config.get('environment_variables', None)
# if environment_variables:
# for key, value in environment_variables.items():
# if environment_variables:
# for key, value in environment_variables.items():
# os.environ[key] = value
# return router, model_list, server_settings

View file

@ -16,11 +16,11 @@ from openai import (
RateLimitError,
APIStatusError,
OpenAIError,
APIError,
APITimeoutError,
APIConnectionError,
APIError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
UnprocessableEntityError
UnprocessableEntityError,
)
import httpx
@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -58,23 +55,21 @@ class BadRequestError(BadRequestError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 408
@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore
request=request
) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429
@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -106,12 +101,13 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
self.model = model
self.llm_provider = llm_provider
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response,
) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503
@ -119,50 +115,42 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
response=response,
body=None
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request):
self.status_code = status_code
class APIError(APIError): # type: ignore
def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request
):
self.status_code = status_code
self.message = message
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message,
request=request, # type: ignore
body=None
)
super().__init__(self.message, request=request, body=None) # type: ignore
# raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore
class APIConnectionError(APIConnectionError): # type: ignore
def __init__(self, message, llm_provider, model, request: httpx.Request):
self.message = message
self.llm_provider = llm_provider
self.model = model
self.status_code = 500
super().__init__(
message=self.message,
request=request
)
super().__init__(message=self.message, request=request)
# raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore
class APIResponseValidationError(APIResponseValidationError): # type: ignore
def __init__(self, message, llm_provider, model):
self.message = message
self.llm_provider = llm_provider
self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
super().__init__(
response=response,
body=None,
message=message
)
super().__init__(response=response, body=None, message=message)
class OpenAIError(OpenAIError): # type: ignore
def __init__(self, original_exception):
@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore
)
self.llm_provider = "openai"
class BudgetExceededError(Exception):
def __init__(self, current_cost, max_budget):
self.current_cost = current_cost
@ -183,7 +172,8 @@ class BudgetExceededError(Exception):
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
super().__init__(message)
## DEPRECATED ##
## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 400

View file

@ -5,32 +5,33 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self):
pass
def log_pre_api_call(self, model, messages, kwargs):
def log_pre_api_call(self, model, messages, kwargs):
pass
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
pass
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
def log_success_event(self, kwargs, response_obj, start_time, end_time):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### ASYNC ####
#### ASYNC ####
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
@ -43,81 +44,87 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### CALL HOOKS - proxy only ####
#### CALL HOOKS - proxy only ####
"""
Control or modify the incoming / outgoing data before calling the model
"""
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings"],
):
pass
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try:
try:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["log_event_type"] = "pre_api_call"
callback_func(
kwargs,
)
print_verbose(
f"Custom Logger - model call details: {kwargs}"
)
except:
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try:
async def async_log_input_event(
self, model, messages, kwargs, print_verbose, callback_func
):
try:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["log_event_type"] = "pre_api_call"
await callback_func(
kwargs,
)
print_verbose(
f"Custom Logger - model call details: {kwargs}"
)
except:
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
def log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
callback_func(
kwargs, # kwargs to func
kwargs, # kwargs to func
response_obj,
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
print_verbose(f"Custom Logger - final response object: {response_obj}")
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func):
async def async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition
try:
kwargs["log_event_type"] = "post_api_call"
await callback_func(
kwargs, # kwargs to func
kwargs, # kwargs to func
response_obj,
start_time,
end_time,
)
print_verbose(
f"Custom Logger - final response object: {response_obj}"
)
print_verbose(f"Custom Logger - final response object: {response_obj}")
except:
# traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
pass
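A minimal subclass sketch of the hooks above; registering the handler through litellm.callbacks follows the custom_callback docs linked in the class comment rather than anything shown in this diff:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # runs after every successful sync call
        print("success:", kwargs.get("model"), start_time, end_time)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # runs after every failed async call
        print("failure:", kwargs.get("model"))

litellm.callbacks = [MyHandler()]  # registration pattern per the linked docs (assumption)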

View file

@ -10,19 +10,28 @@ import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose
class DyanmoDBLogger:
# Class variables or attributes
def __init__(self):
# Instance variables
import boto3
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
self.dynamodb = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None:
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
raise ValueError(
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
)
self.table_name = litellm.dynamodb_table_name
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try:
print_verbose(
@ -32,7 +41,9 @@ class DyanmoDBLogger:
# construct payload to send to DynamoDB
# follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
@ -51,7 +62,7 @@ class DyanmoDBLogger:
"messages": messages,
"response": response_obj,
"usage": usage,
"metadata": metadata
"metadata": metadata,
}
# Ensure everything in the payload is converted to str
@ -62,9 +73,8 @@ class DyanmoDBLogger:
# non blocking if it can't cast to a str
pass
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
# put data in dyanmo DB
table = self.dynamodb.Table(self.table_name)
# Assuming log_data is a dictionary with log information
@ -79,4 +89,4 @@ class DyanmoDBLogger:
except:
traceback.print_exc()
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass
pass
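Using the logger above requires litellm.dynamodb_table_name to be set before any call, per the ValueError in __init__; the table name, region, and the "dynamodb" callback string below are assumptions taken from the litellm docs rather than from this diff:

import os
import litellm

os.environ["AWS_REGION_NAME"] = "us-west-2"  # region used by boto3.resource above (placeholder)
litellm.dynamodb_table_name = "litellm-request-logs"  # placeholder table name
litellm.success_callback = ["dynamodb"]  # assumed callback name that activates the DynamoDB logger

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)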

View file

@ -64,7 +64,9 @@ class LangFuseLogger:
# end of processing langfuse ########################
input = prompt
output = response_obj["choices"][0]["message"].json()
print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}")
print(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
)
self._log_langfuse_v2(
user_id,
metadata,
@ -171,7 +173,6 @@ class LangFuseLogger:
user_id=user_id,
)
trace.generation(
name=metadata.get("generation_name", "litellm-completion"),
startTime=start_time,

View file

@ -8,25 +8,27 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class LangsmithLogger:
# Class variables or attributes
def __init__(self):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
metadata = {}
if "litellm_params" in kwargs:
metadata = kwargs["litellm_params"].get("metadata", {})
# set project name and run_name for langsmith logging
# set project name and run_name for langsmith logging
# users can pass project_name and run name to litellm.completion()
# Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"})
# if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
project_name = metadata.get("project_name", "litellm-completion")
run_name = metadata.get("run_name", "LLMRun")
print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}")
print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try:
print_verbose(
f"Langsmith Logging - Enters logging function for model {kwargs}"
@ -34,6 +36,7 @@ class LangsmithLogger:
import requests
import datetime
from datetime import timezone
try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
@ -45,7 +48,7 @@ class LangsmithLogger:
new_kwargs = {}
for key in kwargs:
value = kwargs[key]
if key == "start_time" or key =="end_time":
if key == "start_time" or key == "end_time":
pass
elif type(value) != dict:
new_kwargs[key] = value
@ -54,18 +57,14 @@ class LangsmithLogger:
"https://api.smith.langchain.com/runs",
json={
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {
**new_kwargs
},
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": {**new_kwargs},
"outputs": response_obj.json(),
"session_name": project_name,
"start_time": start_time,
"end_time": end_time,
},
headers={
"x-api-key": self.langsmith_api_key
}
headers={"x-api-key": self.langsmith_api_key},
)
print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}"

View file

@ -1,6 +1,7 @@
import requests, traceback, json, os
import types
class LiteDebugger:
user_email = None
dashboard_url = None
@ -12,9 +13,15 @@ class LiteDebugger:
def validate_environment(self, email):
try:
self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL"))
if self.user_email == None: # if users are trying to use_client=True but token not set
raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token")
self.user_email = (
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
)
if (
self.user_email == None
): # if users are trying to use_client=True but token not set
raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
)
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try:
print(
@ -42,7 +49,9 @@ class LiteDebugger:
litellm_params,
optional_params,
):
print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}")
print_verbose(
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
@ -56,7 +65,11 @@ class LiteDebugger:
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding":
for message in messages: # assuming the input is a list as required by the embedding function
for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = {
"model": model,
"messages": [{"role": "user", "content": message}],
@ -79,7 +92,9 @@ class LiteDebugger:
elif call_type == "completion":
litellm_data_obj = {
"model": model,
"messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}],
"messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
@ -95,20 +110,30 @@ class LiteDebugger:
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: completion api response - {response.text}")
print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream):
print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}")
def post_call_log_event(
self, original_response, litellm_call_id, print_verbose, call_type, stream
):
print_verbose(
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
)
try:
if call_type == "embedding":
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector
"additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
@ -122,7 +147,11 @@ class LiteDebugger:
elif call_type == "completion" and stream:
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response},
"additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
@ -146,10 +175,12 @@ class LiteDebugger:
end_time,
litellm_call_id,
print_verbose,
call_type,
stream = False
call_type,
stream=False,
):
print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}")
print_verbose(
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
@ -186,7 +217,7 @@ class LiteDebugger:
data=json.dumps(litellm_data_obj),
)
elif call_type == "completion" and stream == True:
if len(response_obj["content"]) > 0: # don't log the empty strings
if len(response_obj["content"]) > 0: # don't log the empty strings
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,

View file

@ -18,19 +18,17 @@ class PromptLayerLogger:
# Method definition
try:
new_kwargs = {}
new_kwargs['model'] = kwargs['model']
new_kwargs['messages'] = kwargs['messages']
new_kwargs["model"] = kwargs["model"]
new_kwargs["messages"] = kwargs["messages"]
# add kwargs["optional_params"] to new_kwargs
for optional_param in kwargs["optional_params"]:
new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
print_verbose(
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
)
request_response = requests.post(
"https://api.promptlayer.com/rest/track-request",
json={
@ -51,8 +49,8 @@ class PromptLayerLogger:
f"Prompt Layer Logging: success - final response object: {request_response.text}"
)
response_json = request_response.json()
if "success" not in request_response.json():
raise Exception("Promptlayer did not successfully log the response!")
if "success" not in request_response.json():
raise Exception("Promptlayer did not successfully log the response!")
if "request_id" in response_json:
print(kwargs["litellm_params"]["metadata"])
@ -62,10 +60,12 @@ class PromptLayerLogger:
json={
"request_id": response_json["request_id"],
"api_key": self.key,
"metadata": kwargs["litellm_params"]["metadata"]
"metadata": kwargs["litellm_params"]["metadata"],
},
)
print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}")
print_verbose(
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
)
except:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")

View file

@ -9,6 +9,7 @@ import traceback
import datetime, subprocess, sys
import litellm
class Supabase:
# Class variables or attributes
supabase_table_name = "request_logs"

View file

@ -2,6 +2,7 @@ class TraceloopLogger:
def __init__(self):
from traceloop.sdk.tracing.tracing import TracerWrapper
from traceloop.sdk import Traceloop
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
self.tracer_wrapper = TracerWrapper()
@ -29,15 +30,18 @@ class TraceloopLogger:
)
if "stop" in optional_params:
span.set_attribute(
SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop")
SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
optional_params.get("stop"),
)
if "frequency_penalty" in optional_params:
span.set_attribute(
SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty")
SpanAttributes.LLM_FREQUENCY_PENALTY,
optional_params.get("frequency_penalty"),
)
if "presence_penalty" in optional_params:
span.set_attribute(
SpanAttributes.LLM_PRESENCE_PENALTY, optional_params.get("presence_penalty")
SpanAttributes.LLM_PRESENCE_PENALTY,
optional_params.get("presence_penalty"),
)
if "top_p" in optional_params:
span.set_attribute(
@ -45,7 +49,10 @@ class TraceloopLogger:
)
if "tools" in optional_params or "functions" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions"))
SpanAttributes.LLM_REQUEST_FUNCTIONS,
optional_params.get(
"tools", optional_params.get("functions")
),
)
if "user" in optional_params:
span.set_attribute(
@ -53,7 +60,8 @@ class TraceloopLogger:
)
if "max_tokens" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens")
SpanAttributes.LLM_REQUEST_MAX_TOKENS,
kwargs.get("max_tokens"),
)
if "temperature" in optional_params:
span.set_attribute(

View file

@ -1,4 +1,4 @@
imported_openAIResponse=True
imported_openAIResponse = True
try:
import io
import logging
@ -12,15 +12,12 @@ try:
else:
from typing_extensions import Literal, Protocol
logger = logging.getLogger(__name__)
K = TypeVar("K", bound=str)
V = TypeVar("V")
class OpenAIResponse(Protocol[K, V]): # type: ignore
class OpenAIResponse(Protocol[K, V]): # type: ignore
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
@ -30,7 +27,6 @@ try:
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(
self,
@ -44,7 +40,9 @@ try:
elif response["object"] == "text_completion":
return self._resolve_completion(request, response, time_elapsed)
elif response["object"] == "chat.completion":
return self._resolve_chat_completion(request, response, time_elapsed)
return self._resolve_chat_completion(
request, response, time_elapsed
)
else:
logger.info(f"Unknown OpenAI response object: {response['object']}")
except Exception as e:
@ -113,7 +111,8 @@ try:
"""Resolves the request and response objects for `openai.Completion`."""
request_str = f"\n\n**Prompt**: {request['prompt']}\n"
choices = [
f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
f"\n\n**Completion**: {choice['text']}\n"
for choice in response["choices"]
]
return self._request_response_result_to_trace(
@ -167,9 +166,9 @@ try:
]
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace
except:
imported_openAIResponse=False
except:
imported_openAIResponse = False
#### What this does ####
@ -182,29 +181,34 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class WeightsBiasesLogger:
# Class variables or attributes
def __init__(self):
try:
import wandb
except:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
if imported_openAIResponse==False:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m")
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
if imported_openAIResponse == False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
self.resolver = OpenAIRequestResponseResolver()
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition
import wandb
try:
print_verbose(
f"W&B Logging - Enters logging function for model {kwargs}"
)
print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
run = wandb.init()
print_verbose(response_obj)
trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds())
trace = self.resolver(
kwargs, response_obj, (end_time - start_time).total_seconds()
)
if trace is not None:
run.log({"trace": trace})

View file

@ -7,76 +7,93 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message
import litellm
class AI21Error(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
self.request = httpx.Request(
method="POST", url="https://api.ai21.com/studio/v1/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AI21Config():
class AI21Config:
"""
Reference: https://docs.ai21.com/reference/j2-complete-ref
The class `AI21Config` provides configuration for the AI21 API interface. Below are the parameters:
- `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `topKReturn` (int32): Range between 0 and 10, inclusive. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object.
"""
numResults: Optional[int]=None
maxTokens: Optional[int]=None
minTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
topKReturn: Optional[int]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self,
numResults: Optional[int]=None,
maxTokens: Optional[int]=None,
minTokens: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[float]=None,
stopSequences: Optional[list]=None,
topKReturn: Optional[int]=None,
frequencePenalty: Optional[dict]=None,
presencePenalty: Optional[dict]=None,
countPenalty: Optional[dict]=None) -> None:
numResults: Optional[int] = None
maxTokens: Optional[int] = None
minTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
topKReturn: Optional[int] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
numResults: Optional[int] = None,
maxTokens: Optional[int] = None,
minTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
topKReturn: Optional[int] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -91,6 +108,7 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
@ -110,20 +128,18 @@ def completion(
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
## Load Config
config = litellm.AI21Config.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AI21Config.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -134,29 +150,26 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
api_base + model + "/complete", headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AI21Error(
status_code=response.status_code,
message=response.text
)
raise AI21Error(status_code=response.status_code, message=response.text)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
completion_response = response.json()
try:
@ -164,18 +177,22 @@ def completion(
for idx, item in enumerate(completion_response["completions"]):
if len(item["data"]["text"]) > 0:
message_obj = Message(content=item["data"]["text"])
else:
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finishReason"]["reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code)
raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
@ -189,6 +206,7 @@ def completion(
}
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
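The "## Load Config" block above gives per-call kwargs precedence over AI21Config defaults (completion(top_k=3) > ai21_config(top_k=3)); a sketch of that precedence, assuming an AI21 key is set in the environment and "j2-mid" as an example model name:

import os
import litellm

os.environ["AI21_API_KEY"] = ""  # placeholder

# set provider-wide defaults; __init__ stores non-None values as class attributes
litellm.AI21Config(maxTokens=256, temperature=0.2)
print(litellm.AI21Config.get_config())  # e.g. {'maxTokens': 256, 'temperature': 0.2}

response = litellm.completion(
    model="j2-mid",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.9,  # per-call value wins over the 0.2 config default
)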

View file

@ -8,17 +8,21 @@ import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx
class AlephAlphaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete")
self.request = httpx.Request(
method="POST", url="https://api.aleph-alpha.com/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AlephAlphaConfig():
class AlephAlphaConfig:
"""
Reference: https://docs.aleph-alpha.com/api/complete/
@ -42,13 +46,13 @@ class AlephAlphaConfig():
- `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied.
- `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
- `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
- `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty.
- `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions.
- `best_of` (integer, nullable; default value: 1): The number of completions that will be generated on the server side.
- `best_of` (integer, nullable; default value: 1): The number of completions that will be generated on the server side.
- `n` (integer, nullable; default value: 1): The number of completions to return.
@ -68,87 +72,101 @@ class AlephAlphaConfig():
- `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion.
- `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
- `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
"""
maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int]=None
echo: Optional[bool]=None
temperature: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
presence_penalty: Optional[int]=None
frequency_penalty: Optional[int]=None
sequence_penalty: Optional[int]=None
sequence_penalty_min_length: Optional[int]=None
repetition_penalties_include_prompt: Optional[bool]=None
repetition_penalties_include_completion: Optional[bool]=None
use_multiplicative_presence_penalty: Optional[bool]=None
use_multiplicative_frequency_penalty: Optional[bool]=None
use_multiplicative_sequence_penalty: Optional[bool]=None
penalty_bias: Optional[str]=None
penalty_exceptions_include_stop_sequences: Optional[bool]=None
best_of: Optional[int]=None
n: Optional[int]=None
logit_bias: Optional[dict]=None
log_probs: Optional[int]=None
stop_sequences: Optional[list]=None
tokens: Optional[bool]=None
raw_completion: Optional[bool]=None
disable_optimizations: Optional[bool]=None
completion_bias_inclusion: Optional[list]=None
completion_bias_exclusion: Optional[list]=None
completion_bias_inclusion_first_token_only: Optional[bool]=None
completion_bias_exclusion_first_token_only: Optional[bool]=None
contextual_control_threshold: Optional[int]=None
control_log_additive: Optional[bool]=None
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
presence_penalty: Optional[int] = None
frequency_penalty: Optional[int] = None
sequence_penalty: Optional[int] = None
sequence_penalty_min_length: Optional[int] = None
repetition_penalties_include_prompt: Optional[bool] = None
repetition_penalties_include_completion: Optional[bool] = None
use_multiplicative_presence_penalty: Optional[bool] = None
use_multiplicative_frequency_penalty: Optional[bool] = None
use_multiplicative_sequence_penalty: Optional[bool] = None
penalty_bias: Optional[str] = None
penalty_exceptions_include_stop_sequences: Optional[bool] = None
best_of: Optional[int] = None
n: Optional[int] = None
logit_bias: Optional[dict] = None
log_probs: Optional[int] = None
stop_sequences: Optional[list] = None
tokens: Optional[bool] = None
raw_completion: Optional[bool] = None
disable_optimizations: Optional[bool] = None
completion_bias_inclusion: Optional[list] = None
completion_bias_exclusion: Optional[list] = None
completion_bias_inclusion_first_token_only: Optional[bool] = None
completion_bias_exclusion_first_token_only: Optional[bool] = None
contextual_control_threshold: Optional[int] = None
control_log_additive: Optional[bool] = None
def __init__(self,
maximum_tokens: Optional[int]=None,
minimum_tokens: Optional[int]=None,
echo: Optional[bool]=None,
temperature: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[int]=None,
presence_penalty: Optional[int]=None,
frequency_penalty: Optional[int]=None,
sequence_penalty: Optional[int]=None,
sequence_penalty_min_length: Optional[int]=None,
repetition_penalties_include_prompt: Optional[bool]=None,
repetition_penalties_include_completion: Optional[bool]=None,
use_multiplicative_presence_penalty: Optional[bool]=None,
use_multiplicative_frequency_penalty: Optional[bool]=None,
use_multiplicative_sequence_penalty: Optional[bool]=None,
penalty_bias: Optional[str]=None,
penalty_exceptions_include_stop_sequences: Optional[bool]=None,
best_of: Optional[int]=None,
n: Optional[int]=None,
logit_bias: Optional[dict]=None,
log_probs: Optional[int]=None,
stop_sequences: Optional[list]=None,
tokens: Optional[bool]=None,
raw_completion: Optional[bool]=None,
disable_optimizations: Optional[bool]=None,
completion_bias_inclusion: Optional[list]=None,
completion_bias_exclusion: Optional[list]=None,
completion_bias_inclusion_first_token_only: Optional[bool]=None,
completion_bias_exclusion_first_token_only: Optional[bool]=None,
contextual_control_threshold: Optional[int]=None,
control_log_additive: Optional[bool]=None) -> None:
def __init__(
self,
maximum_tokens: Optional[int] = None,
minimum_tokens: Optional[int] = None,
echo: Optional[bool] = None,
temperature: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
presence_penalty: Optional[int] = None,
frequency_penalty: Optional[int] = None,
sequence_penalty: Optional[int] = None,
sequence_penalty_min_length: Optional[int] = None,
repetition_penalties_include_prompt: Optional[bool] = None,
repetition_penalties_include_completion: Optional[bool] = None,
use_multiplicative_presence_penalty: Optional[bool] = None,
use_multiplicative_frequency_penalty: Optional[bool] = None,
use_multiplicative_sequence_penalty: Optional[bool] = None,
penalty_bias: Optional[str] = None,
penalty_exceptions_include_stop_sequences: Optional[bool] = None,
best_of: Optional[int] = None,
n: Optional[int] = None,
logit_bias: Optional[dict] = None,
log_probs: Optional[int] = None,
stop_sequences: Optional[list] = None,
tokens: Optional[bool] = None,
raw_completion: Optional[bool] = None,
disable_optimizations: Optional[bool] = None,
completion_bias_inclusion: Optional[list] = None,
completion_bias_exclusion: Optional[list] = None,
completion_bias_inclusion_first_token_only: Optional[bool] = None,
completion_bias_exclusion_first_token_only: Optional[bool] = None,
contextual_control_threshold: Optional[int] = None,
control_log_additive: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -160,6 +178,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -177,9 +196,11 @@ def completion(
headers = validate_environment(api_key)
## Load Config
config = litellm.AlephAlphaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AlephAlphaConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url = api_base
@ -188,21 +209,17 @@ def completion(
if "control" in model: # follow the ###Instruction / ###Response format
for idx, message in enumerate(messages):
if "role" in message:
if idx == 0: # set first message as instruction (required), let later user messages be input
if (
idx == 0
): # set first message as instruction (required), let later user messages be input
prompt += f"###Instruction: {message['content']}"
else:
if message["role"] == "system":
prompt += (
f"###Instruction: {message['content']}"
)
prompt += f"###Instruction: {message['content']}"
elif message["role"] == "user":
prompt += (
f"###Input: {message['content']}"
)
prompt += f"###Input: {message['content']}"
else:
prompt += (
f"###Response: {message['content']}"
)
prompt += f"###Response: {message['content']}"
else:
prompt += f"{message['content']}"
else:
@ -215,24 +232,27 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
@@ -247,18 +267,23 @@ def completion(
for idx, item in enumerate(completion_response["completions"]):
if len(item["completion"]) > 0:
message_obj = Message(content=item["completion"])
else:
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except:
raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code)
raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
## CALCULATING USAGE - Aleph Alpha doesn't return token counts, so estimate them locally with the encoding.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
@@ -268,11 +293,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
View file
@@ -5,56 +5,76 @@ import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AnthropicError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete")
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AnthropicConfig():
class AnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
To pass metadata to Anthropic, use e.g. metadata={"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
metadata: Optional[dict]=None
def __init__(self,
max_tokens_to_sample: Optional[int]=256, # anthropic requires a default
stop_sequences: Optional[list]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,
top_k: Optional[int]=None,
metadata: Optional[dict]=None) -> None:
max_tokens_to_sample: Optional[
int
] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
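Every non-None constructor argument is promoted to a class attribute, so instantiating the config once sets provider-wide defaults that get_config() later merges into each request. A hedged sketch, assuming the class is exported as litellm.AnthropicConfig (as the get_config() call further down suggests):

import litellm

# Illustrative defaults; metadata follows the {"user_id": ...} shape noted in the docstring.
litellm.AnthropicConfig(max_tokens_to_sample=512, metadata={"user_id": "example-user"})
print(litellm.AnthropicConfig.get_config())
# e.g. {'max_tokens_to_sample': 512, 'metadata': {'user_id': 'example-user'}}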
# makes headers for API call
@@ -71,6 +91,7 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
@@ -87,21 +108,25 @@ def completion(
):
headers = validate_environment(api_key)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
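The registered template is a dict of per-role wrappers plus initial/final strings; a rough sketch of one entry, where the wrapper values (and the pre_message/post_message keys) are assumptions about prompt_templates/factory.py rather than anything defined in this file:

# Illustrative shape only.
custom_prompt_dict = {
    "claude-instant-1": {
        "roles": {
            "system": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "\n\nAssistant: ",
    }
}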
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@@ -116,7 +141,7 @@ def completion(
api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
@@ -125,18 +150,20 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
@@ -159,9 +186,9 @@ def completion(
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response[
"completion"
]
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
@@ -177,11 +204,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
View file
@@ -1,7 +1,13 @@
from typing import Optional, Union, Any
import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
)
from typing import Callable, Optional
from litellm import OpenAIConfig
import litellm, json
@@ -9,8 +15,15 @@ import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request:
@@ -20,11 +33,14 @@ class AzureOpenAIError(Exception):
if response:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AzureOpenAIConfig(OpenAIConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -49,33 +65,37 @@ class AzureOpenAIConfig(OpenAIConfig):
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
def __init__(self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]]= None,
functions: Optional[list]= None,
logit_bias: Optional[dict]= None,
max_tokens: Optional[int]= None,
n: Optional[int]= None,
presence_penalty: Optional[int]= None,
stop: Optional[Union[str,list]]=None,
temperature: Optional[int]= None,
top_p: Optional[int]= None) -> None:
super().__init__(frequency_penalty,
function_call,
functions,
logit_bias,
max_tokens,
n,
presence_penalty,
stop,
temperature,
top_p)
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty,
function_call,
functions,
logit_bias,
max_tokens,
n,
presence_penalty,
stop,
temperature,
top_p,
)
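Since the constructor forwards every argument to OpenAIConfig unchanged, Azure deployments accept the same tunables as plain OpenAI. A short illustrative sketch, assuming the class is exported as litellm.AzureOpenAIConfig like the other provider configs:

import litellm

# Illustrative defaults only; get_config() is inherited from OpenAIConfig.
litellm.AzureOpenAIConfig(max_tokens=256, temperature=0.2)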
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
@@ -89,49 +109,51 @@ class AzureChatCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def completion(self,
model: str,
messages: list,
model_response: ModelResponse,
api_key: str,
api_base: str,
api_version: str,
api_type: str,
azure_ad_token: str,
print_verbose: Callable,
timeout,
logging_obj,
optional_params,
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict]=None,
client = None,
):
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
api_key: str,
api_base: str,
api_version: str,
api_type: str,
azure_ad_token: str,
print_verbose: Callable,
timeout,
logging_obj,
optional_params,
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict] = None,
client=None,
):
super().completion()
exception_mapping_worked = False
try:
if model is None or messages is None:
raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2)
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
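The effect is that a Cloudflare AI Gateway base URL simply gains the deployment name as its final path segment; a sketch with placeholder values:

# Placeholder values, purely to show the resulting URL shape.
api_base = "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/azure-openai/RESOURCE"
model = "my-gpt-35-deployment"
if not api_base.endswith("/"):
    api_base += "/"
api_base += f"{model}"
# -> .../azure-openai/RESOURCE/my-gpt-35-deployment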
azure_client_params = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -142,26 +164,53 @@ class AzureChatCompletion(BaseLLM):
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
data = {"model": None, "messages": messages, **optional_params}
else:
data = {
"model": None,
"messages": messages,
**optional_params
"model": model, # type: ignore
"messages": messages,
**optional_params,
}
else:
data = {
"model": model, # type: ignore
"messages": messages,
**optional_params
}
if acompletion is True:
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj)
return self.acompletion(
api_base=api_base,
data=data,
model_response=model_response,
api_key=api_key,
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else:
## LOGGING
logging_obj.pre_call(
@@ -169,16 +218,18 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key,
additional_args={
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token
"api_key": api_key,
"azure_ad_token": azure_ad_token,
},
"api_version": api_version,
"api_base": api_base,
"complete_input_dict": data,
},
)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -196,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
response = azure_client.chat.completions.create(**data) # type: ignore
response = azure_client.chat.completions.create(**data) # type: ignore
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
@@ -209,30 +260,36 @@ class AzureChatCompletion(BaseLLM):
"api_base": api_base,
},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except AzureOpenAIError as e:
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
except Exception as e:
raise e
async def acompletion(self,
api_key: str,
api_version: str,
model: str,
api_base: str,
data: dict,
timeout: Any,
model_response: ModelResponse,
azure_ad_token: Optional[str]=None,
client = None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
async def acompletion(
self,
api_key: str,
api_version: str,
model: str,
api_base: str,
data: dict,
timeout: Any,
model_response: ModelResponse,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -252,35 +309,46 @@ class AzureChatCompletion(BaseLLM):
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except AzureOpenAIError as e:
response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(
response_object=json.loads(response.model_dump_json()),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
except Exception as e:
if hasattr(e, "status_code"):
raise e
else:
raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None,
client=None,
):
def streaming(
self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -300,25 +368,36 @@ class AzureChatCompletion(BaseLLM):
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper
async def async_streaming(self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None,
client = None,
):
async def async_streaming(
self,
logging_obj,
api_base: str,
api_key: str,
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
client=None,
):
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@@ -326,39 +405,49 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2),
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params)
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=data['messages'],
input=data["messages"],
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
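async_streaming is itself an async generator, so a caller drains it with async for; a hedged sketch of a consumer (the arguments to async_streaming are elided because they mirror the signature above):

import asyncio

async def consume(stream):
    # Each item is a chunk already transformed by CustomStreamWrapper.
    async for chunk in stream:
        print(chunk)

# asyncio.run(consume(AzureChatCompletion().async_streaming(...)))  # arguments elided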
async def aembedding(
self,
data: dict,
model_response: ModelResponse,
self,
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
api_key: str,
input: list,
client=None,
logging_obj=None
):
logging_obj=None,
):
response = None
try:
try:
if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else:
@@ -367,50 +456,53 @@ class AzureChatCompletion(BaseLLM):
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="embedding",
)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def embedding(self,
model: str,
input: list,
api_key: str,
api_base: str,
api_version: str,
timeout: float,
logging_obj=None,
model_response=None,
optional_params=None,
azure_ad_token: Optional[str]=None,
client = None,
aembedding=None,
):
def embedding(
self,
model: str,
input: list,
api_key: str,
api_base: str,
api_version: str,
timeout: float,
logging_obj=None,
model_response=None,
optional_params=None,
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
):
super().embedding()
exception_mapping_worked = False
if self._client_session is None:
self._client_session = self.create_client_session()
try:
data = {
"model": model,
"input": input,
**optional_params
}
try:
data = {"model": model, "input": input, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
@@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -427,119 +519,130 @@ class AzureChatCompletion(BaseLLM):
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token
}
},
)
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
api_key=api_key,
model_response=model_response,
azure_client_params=azure_client_params,
)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore
## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base},
original_response=response,
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base},
original_response=response,
)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e:
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
except Exception as e:
if exception_mapping_worked:
raise e
else:
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation(
self,
data: dict,
model_response: ModelResponse,
self,
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
api_key: str,
input: list,
client=None,
logging_obj=None
):
logging_obj=None,
):
response = None
try:
try:
if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),)
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params)
client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
)
openai_aclient = AsyncAzureOpenAI(
http_client=client_session, **azure_client_params
)
else:
openai_aclient = client
response = await openai_aclient.images.generate(**data)
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation")
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="image_generation",
)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def image_generation(self,
prompt: str,
timeout: float,
model: Optional[str]=None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str]=None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
def image_generation(
self,
prompt: str,
timeout: float,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
try:
if model and len(model) > 0:
model = model
else:
model = None
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int")
if not isinstance(max_retries, int):
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
@@ -547,39 +650,47 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
return response
if client is None:
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(),
)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
)
## COMPLETION CALL
response = azure_client.images.generate(**data) # type: ignore
response = azure_client.images.generate(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
except Exception as e:
if exception_mapping_worked:
raise e
else:
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
View file
@@ -1,47 +1,45 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import litellm
import httpx, certifi, ssl
from typing import Optional
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def create_client_session(self):
if litellm.client_session:
if litellm.client_session:
_client_session = litellm.client_session
else:
else:
_client_session = httpx.Client()
return _client_session
def create_aclient_session(self):
if litellm.aclient_session:
if litellm.aclient_session:
_aclient_session = litellm.aclient_session
else:
else:
_aclient_session = httpx.AsyncClient()
return _aclient_session
def __exit__(self):
if hasattr(self, '_client_session'):
if hasattr(self, "_client_session"):
self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'):
if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose()
def validate_environment(self): # set up the environment required to run the model
pass
def completion(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model completion calls
pass
def embedding(
self,
*args,
**kwargs
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model embedding calls
pass
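Both session helpers defer to module-level clients when they exist, so one pooled httpx client can be shared across providers; a minimal sketch that relies only on the litellm.client_session / litellm.aclient_session attributes checked above:

import httpx
import litellm

# Illustrative: reuse pooled connections instead of creating a fresh client per request.
litellm.client_session = httpx.Client(timeout=60.0)
litellm.aclient_session = httpx.AsyncClient(timeout=60.0)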
View file
@@ -6,6 +6,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse, Usage
class BasetenError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@@ -14,6 +15,7 @@ class BasetenError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
def validate_environment(api_key):
headers = {
"accept": "application/json",
@@ -23,6 +25,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Api-Key {api_key}"
return headers
def completion(
model: str,
messages: list,
@@ -52,32 +55,38 @@ def completion(
"inputs": prompt,
"prompt": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers,
data=json.dumps(data),
stream=True if "stream" in optional_params and optional_params["stream"] == True else False
stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
)
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
):
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
@@ -91,9 +100,7 @@ def completion(
if (
isinstance(completion_response["model_output"], dict)
and "data" in completion_response["model_output"]
and isinstance(
completion_response["model_output"]["data"], list
)
and isinstance(completion_response["model_output"]["data"], list)
):
model_response["choices"][0]["message"][
"content"
@@ -112,12 +119,19 @@ def completion(
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
status_code=response.status_code,
)
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
## GETTING LOGPROBS
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
@@ -125,7 +139,7 @@ def completion(
else:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@@ -139,11 +153,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
View file
@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
class BedrockError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock")
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AmazonTitanConfig():
class AmazonTitanConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@@ -29,29 +33,44 @@ class AmazonTitanConfig():
- `temperature` (float) temperature for model,
- `topP` (int) top p for model
"""
maxTokenCount: Optional[int]=None
stopSequences: Optional[list]=None
temperature: Optional[float]=None
topP: Optional[int]=None
def __init__(self,
maxTokenCount: Optional[int]=None,
stopSequences: Optional[list]=None,
temperature: Optional[float]=None,
topP: Optional[int]=None) -> None:
maxTokenCount: Optional[int] = None
stopSequences: Optional[list] = None
temperature: Optional[float] = None
topP: Optional[int] = None
def __init__(
self,
maxTokenCount: Optional[int] = None,
stopSequences: Optional[list] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
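As with the other provider configs, non-None constructor arguments become class attributes and get_config() later feeds them into the request body (for Titan, the textGenerationConfig block); a sketch with illustrative values:

import litellm

# Illustrative defaults for amazon.titan-* models.
litellm.AmazonTitanConfig(maxTokenCount=512, temperature=0.5)
print(litellm.AmazonTitanConfig.get_config())
# e.g. {'maxTokenCount': 512, 'temperature': 0.5}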
class AmazonAnthropicConfig():
class AmazonAnthropicConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
"""
max_tokens_to_sample: Optional[int]=litellm.max_tokens
stop_sequences: Optional[list]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
anthropic_version: Optional[str]=None
def __init__(self,
max_tokens_to_sample: Optional[int]=None,
stop_sequences: Optional[list]=None,
temperature: Optional[float]=None,
top_k: Optional[int]=None,
top_p: Optional[int]=None,
anthropic_version: Optional[str]=None) -> None:
max_tokens_to_sample: Optional[int] = litellm.max_tokens
stop_sequences: Optional[list] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
anthropic_version: Optional[str] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = None,
stop_sequences: Optional[list] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonCohereConfig():
class AmazonCohereConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@@ -100,79 +134,110 @@ class AmazonCohereConfig():
- `temperature` (float) model temperature,
- `return_likelihood` (string) n/a
"""
max_tokens: Optional[int]=None
temperature: Optional[float]=None
return_likelihood: Optional[str]=None
def __init__(self,
max_tokens: Optional[int]=None,
temperature: Optional[float]=None,
return_likelihood: Optional[str]=None) -> None:
max_tokens: Optional[int] = None
temperature: Optional[float] = None
return_likelihood: Optional[str] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAI21Config():
class AmazonAI21Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
Supported Params for the Amazon / AI21 models:
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object.
"""
maxTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self,
maxTokens: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[float]=None,
stopSequences: Optional[list]=None,
frequencePenalty: Optional[dict]=None,
presencePenalty: Optional[dict]=None,
countPenalty: Optional[dict]=None) -> None:
maxTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
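The penalty parameters are forwarded as objects; their exact schema is not defined in this file, so the {"scale": ...} shape below is an assumption based on AI21's public docs:

import litellm

# Illustrative values only.
litellm.AmazonAI21Config(maxTokens=300, temperature=0.5, countPenalty={"scale": 0.2})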
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AmazonLlamaConfig():
class AmazonLlamaConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
@@ -182,48 +247,72 @@ class AmazonLlamaConfig():
- `temperature` (float) temperature for model,
- `top_p` (float) top p for model
"""
max_gen_len: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
def __init__(self,
maxTokenCount: Optional[int]=None,
temperature: Optional[float]=None,
topP: Optional[int]=None) -> None:
max_gen_len: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
def __init__(
self,
max_gen_len: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def init_bedrock_client(
region_name = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] =None,
aws_bedrock_runtime_endpoint: Optional[str]=None,
):
region_name=None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_bedrock_runtime_endpoint: Optional[str] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith('os.environ/'):
if param and param.startswith("os.environ/"):
params_to_check[i] = get_secret(param)
# Assign updated values back to parameters
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
) = params_to_check
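The os.environ/ prefix lets a config reference credentials by name (resolved through get_secret) instead of embedding the value; a sketch of the substitution with a hypothetical reference:

# Hypothetical illustration of the substitution performed above.
aws_access_key_id = "os.environ/AWS_ACCESS_KEY_ID"
if aws_access_key_id.startswith("os.environ/"):
    aws_access_key_id = get_secret(aws_access_key_id)  # resolves to the env var's value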
if region_name:
pass
elif aws_region_name:
@@ -233,7 +322,10 @@ def init_bedrock_client(
elif standard_aws_region_name:
region_name = standard_aws_region_name
else:
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
raise BedrockError(
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
status_code=401,
)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@@ -242,9 +334,10 @@ def init_bedrock_client(
elif env_aws_bedrock_runtime_endpoint:
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com'
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
import boto3
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@@ -257,7 +350,7 @@ def init_bedrock_client(
endpoint_url=endpoint_url,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
client = boto3.client(
@@ -276,25 +369,23 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt
@@ -309,17 +400,18 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion(
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
exception_mapping_worked = False
try:
@@ -327,7 +419,9 @@ def completion(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop(
@@ -343,67 +437,71 @@ def completion(
model = model
provider = model.split(".")[0]
prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", False)
if provider == "anthropic":
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AmazonAI21Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "cohere":
## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if optional_params.get("stream", False) == True:
inference_params["stream"] = True # cohere requires stream = True in inference params
data = json.dumps({
"prompt": prompt,
**inference_params
})
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "meta":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"prompt": prompt,
**inference_params
})
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "amazon": # amazon titan
## LOAD CONFIG
config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({
"inputText": prompt,
"textGenerationConfig": inference_params,
})
data = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": inference_params,
}
)
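For the Titan branch the serialized body therefore carries the prompt under inputText and the merged config under textGenerationConfig; roughly (values illustrative):

# Illustrative payload for an amazon.titan-* model.
data = json.dumps(
    {
        "inputText": "\n\nHuman: say hi\n\nAssistant: ",
        "textGenerationConfig": {"maxTokenCount": 256, "temperature": 0.7, "topP": 1},
    }
)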
## COMPLETION CALL
accept = 'application/json'
contentType = 'application/json'
accept = "application/json"
contentType = "application/json"
if stream == True:
if provider == "ai21":
## LOGGING
@@ -418,17 +516,17 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = client.invoke_model(
body=data,
modelId=model,
accept=accept,
contentType=contentType
body=data, modelId=model, accept=accept, contentType=contentType
)
response = response.get('body').read()
response = response.get("body").read()
return response
else:
## LOGGING
@@ -441,20 +539,20 @@ def completion(
)
"""
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = client.invoke_model_with_response_stream(
body=data,
modelId=model,
accept=accept,
contentType=contentType
body=data, modelId=model, accept=accept, contentType=contentType
)
response = response.get('body')
response = response.get("body")
return response
try:
try:
## LOGGING
request_str = f"""
response = client.invoke_model(
@ -465,20 +563,20 @@ def completion(
)
"""
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
response = client.invoke_model(
body=data,
modelId=model,
accept=accept,
contentType=contentType
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
except Exception as e:
response = client.invoke_model(
body=data, modelId=model, accept=accept, contentType=contentType
)
except Exception as e:
raise BedrockError(status_code=500, message=str(e))
response_body = json.loads(response.get('body').read())
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
@ -491,16 +589,16 @@ def completion(
## RESPONSE OBJECT
outputText = "default"
if provider == "ai21":
outputText = response_body.get('completions')[0].get('data').get('text')
outputText = response_body.get("completions")[0].get("data").get("text")
elif provider == "anthropic":
outputText = response_body['completion']
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
elif provider == "cohere":
elif provider == "cohere":
outputText = response_body["generations"][0]["text"]
elif provider == "meta":
elif provider == "meta":
outputText = response_body["generation"]
else: # amazon titan
outputText = response_body.get('results')[0].get('outputText')
outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400:
@ -513,12 +611,13 @@ def completion(
if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText
except:
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
raise BedrockError(
message=json.dumps(outputText),
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -528,41 +627,47 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens = prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except BedrockError as e:
exception_mapping_worked = True
raise e
except Exception as e:
except Exception as e:
if exception_mapping_worked:
raise e
else:
else:
import traceback
raise BedrockError(status_code=500, message=traceback.format_exc())
def _embedding_func_single(
model: str,
input: str,
client: Any,
optional_params=None,
encoding=None,
logging_obj=None,
model: str,
input: str,
client: Any,
optional_params=None,
encoding=None,
logging_obj=None,
):
# logic for parsing in - calling - parsing out model embedding calls
## FORMAT EMBEDDING INPUT ##
## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop("user", None) # make sure user is not passed in for bedrock call
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
if provider == "amazon":
input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params}
# data = json.dumps(data)
elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8")
inference_params["input_type"] = inference_params.get(
"input_type", "search_document"
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
@ -570,12 +675,14 @@ def _embedding_func_single(
modelId={model},
accept="*/*",
contentType="application/json",
)""" # type: ignore
)""" # type: ignore
logging_obj.pre_call(
input=input,
api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model,
"texts": input}, "request_str": request_str},
api_key="", # boto3 is used for init.
additional_args={
"complete_input_dict": {"model": model, "texts": input},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
@ -587,11 +694,11 @@ def _embedding_func_single(
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
if provider == "cohere":
response = response_body.get("embeddings")
# flatten list
@ -600,7 +707,10 @@ def _embedding_func_single(
elif provider == "amazon":
return response_body.get("embedding")
except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
def embedding(
model: str,
@ -616,7 +726,9 @@ def embedding(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
@ -624,11 +736,19 @@ def embedding(
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
)
)
## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
embeddings = [
_embedding_func_single(
model,
i,
optional_params=optional_params,
client=client,
logging_obj=logging_obj,
)
for i in input
] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary
embedding_response = []
@ -647,13 +767,11 @@ def embedding(
input_str = "".join(input)
input_tokens+=len(encoding.encode(input_str))
input_tokens += len(encoding.encode(input_str))
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens + 0
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
)
model_response.usage = usage
return model_response
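The Bedrock branches above all reduce to the same recipe: load the provider's config via get_config(), let caller-supplied inference params win over config defaults (the repeated "completion(top_k=3) > ..._config(top_k=3)" comments), then serialize a provider-specific body — "prompt" for AI21, Cohere and Meta, "inputText" plus "textGenerationConfig" for Amazon Titan. A minimal standalone sketch of that precedence rule, using hypothetical key names rather than litellm internals:

import json


def apply_config_defaults(inference_params: dict, provider_config: dict) -> dict:
    # caller-supplied values win; config values only fill keys the caller left unset
    merged = dict(inference_params)
    for k, v in provider_config.items():
        if k not in merged:
            merged[k] = v
    return merged


# completion(top_k=3) beats a config default of top_k=1
params = apply_config_defaults({"top_k": 3}, {"top_k": 1, "maxTokenCount": 256})
titan_body = json.dumps({"inputText": "hi", "textGenerationConfig": params})
ai21_body = json.dumps({"prompt": "hi", **params})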
@ -8,88 +8,106 @@ from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx
class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate")
self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class CohereConfig():
class CohereConfig:
"""
Reference: https://docs.cohere.com/reference/generate
The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters:
- `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
- `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
- `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
- `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
- `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
- `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
- `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
- `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
- `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
- `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
"""
num_generations: Optional[int]=None
max_tokens: Optional[int]=None
truncate: Optional[str]=None
temperature: Optional[int]=None
preset: Optional[str]=None
end_sequences: Optional[list]=None
stop_sequences: Optional[list]=None
k: Optional[int]=None
p: Optional[int]=None
frequency_penalty: Optional[int]=None
presence_penalty: Optional[int]=None
return_likelihoods: Optional[str]=None
logit_bias: Optional[dict]=None
def __init__(self,
num_generations: Optional[int]=None,
max_tokens: Optional[int]=None,
truncate: Optional[str]=None,
temperature: Optional[int]=None,
preset: Optional[str]=None,
end_sequences: Optional[list]=None,
stop_sequences: Optional[list]=None,
k: Optional[int]=None,
p: Optional[int]=None,
frequency_penalty: Optional[int]=None,
presence_penalty: Optional[int]=None,
return_likelihoods: Optional[str]=None,
logit_bias: Optional[dict]=None) -> None:
num_generations: Optional[int] = None
max_tokens: Optional[int] = None
truncate: Optional[str] = None
temperature: Optional[int] = None
preset: Optional[str] = None
end_sequences: Optional[list] = None
stop_sequences: Optional[list] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
return_likelihoods: Optional[str] = None
logit_bias: Optional[dict] = None
def __init__(
self,
num_generations: Optional[int] = None,
max_tokens: Optional[int] = None,
truncate: Optional[str] = None,
temperature: Optional[int] = None,
preset: Optional[str] = None,
end_sequences: Optional[list] = None,
stop_sequences: Optional[list] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
return_likelihoods: Optional[str] = None,
logit_bias: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
headers = {
@ -100,6 +118,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -119,9 +138,11 @@ def completion(
prompt = " ".join(message["content"] for message in messages)
## Load Config
config=litellm.CohereConfig.get_config()
config = litellm.CohereConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -132,16 +153,23 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
)
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
## error handling for cohere calls
if response.status_code!=200:
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True:
@ -149,11 +177,11 @@ def completion(
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
@ -168,18 +196,22 @@ def completion(
for idx, item in enumerate(completion_response["generations"]):
if len(item["text"]) > 0:
message_obj = Message(content=item["text"])
else:
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code)
raise CohereError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -189,11 +221,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(
model: str,
input: list,
@ -206,11 +239,7 @@ def embedding(
headers = validate_environment(api_key)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {
"model": model,
"texts": input,
**optional_params
}
data = {"model": model, "texts": input, **optional_params}
if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
@ -218,21 +247,19 @@ def embedding(
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
"""
response
{
@ -244,30 +271,23 @@ def embedding(
'usage'
}
"""
if response.status_code!=200:
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
embeddings = response.json()['embeddings']
embeddings = response.json()["embeddings"]
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = {
"prompt_tokens": input_tokens,
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
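CohereConfig above (and the Gemini, Huggingface and MaritTalk configs later in this diff) all rely on one mechanism: optional settings live as class attributes, __init__ promotes every non-None argument onto the class, and get_config() returns only the attributes that are set and are neither dunders nor functions/descriptors. A compact, self-contained illustration of that behaviour; the class and values here are made up, not part of litellm:

import types


class ExampleConfig:
    # class-level defaults in the same style as CohereConfig
    max_tokens = None  # unset defaults are skipped by get_config()
    temperature = 0.75

    def __init__(self, max_tokens=None, temperature=None) -> None:
        for key, value in locals().items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod),
            )
            and v is not None
        }


print(ExampleConfig.get_config())  # {'temperature': 0.75}
ExampleConfig(max_tokens=10)  # setting a value promotes it onto the class
print(ExampleConfig.get_config())  # {'max_tokens': 10, 'temperature': 0.75}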
@ -1,20 +1,24 @@
import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"""
Async implementation of custom http transport
"""
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response(
status_code=400,
headers=response.headers,
@ -56,26 +65,30 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
)
return await super().handle_async_request(request)
class CustomHTTPTransport(httpx.HTTPTransport):
"""
This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692
"""
def handle_request(
self,
request: httpx.Request,
) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
"2023-10-01-preview",
]:
request.url = request.url.copy_with(path="/openai/images/generations:submit")
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = super().handle_request(request)
operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url)
@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport):
start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response(
status_code=400,
headers=response.headers,
@ -115,4 +133,4 @@ class CustomHTTPTransport(httpx.HTTPTransport):
content=json.dumps(result).encode("utf-8"),
request=request,
)
return super().handle_request(request)
return super().handle_request(request)
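Both transports above implement the same long-poll: after redirecting the dall-e request to the :submit endpoint, they keep re-fetching the operation-location URL until the JSON status is "succeeded" or "failed", and give up with a synthetic timeout error. A stripped-down sketch of just that loop, with a hypothetical get_status callable in place of the httpx plumbing:

import time


def poll_operation(get_status, timeout_secs: float = 120.0, interval: float = 1.0) -> dict:
    # stop on a terminal status, or report a timeout in the same error shape as above
    start_time = time.time()
    status = get_status()
    while status not in ("succeeded", "failed"):
        if time.time() - start_time > timeout_secs:
            return {"error": {"code": "Timeout", "message": "Operation polling timed out."}}
        time.sleep(interval)
        status = get_status()
    return {"status": status}


# toy usage: a status source that succeeds on the third poll
states = iter(["notRunning", "running", "succeeded"])
print(poll_operation(lambda: next(states), timeout_secs=5, interval=0))  # {'status': 'succeeded'}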
@ -8,17 +8,22 @@ import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class GeminiError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class GeminiConfig():
class GeminiConfig:
"""
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
@ -37,33 +42,44 @@ class GeminiConfig():
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
"""
candidate_count: Optional[int]=None
stop_sequences: Optional[list]=None
max_output_tokens: Optional[int]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
candidate_count: Optional[int] = None
stop_sequences: Optional[list] = None
max_output_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(self,
candidate_count: Optional[int]=None,
stop_sequences: Optional[list]=None,
max_output_tokens: Optional[int]=None,
temperature: Optional[float]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None) -> None:
def __init__(
self,
candidate_count: Optional[int] = None,
stop_sequences: Optional[list] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
@ -83,42 +99,50 @@ def completion(
try:
import google.generativeai as genai
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini")
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
)
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
)
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f'models/{model}')
response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params))
try:
_model = genai.GenerativeModel(f"models/{model}")
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
except Exception as e:
raise GeminiError(
message=str(e),
@ -127,11 +151,11 @@ def completion(
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": {}},
)
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": {}},
)
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = response
@ -142,31 +166,34 @@ def completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj)
choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
raise GeminiError(message=traceback.format_exc(), status_code=response.status_code)
try:
raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
raise GeminiError(status_code=400, message=f"No response received. Original response - {response}")
raise GeminiError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE
prompt_str = ""
prompt_str = ""
for m in messages:
if isinstance(m["content"], str):
prompt_str += m["content"]
elif isinstance(m["content"], list):
for content in m["content"]:
if content["type"] == "text":
if content["type"] == "text":
prompt_str += content["text"]
prompt_tokens = len(
encoding.encode(prompt_str)
)
prompt_tokens = len(encoding.encode(prompt_str))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -174,13 +201,14 @@ def completion(
model_response["created"] = int(time.time())
model_response["model"] = "gemini/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
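The usage block near the end of the Gemini handler has to handle message content that is either a plain string or a list of typed parts, and only the text parts count toward prompt tokens. A self-contained sketch of that flattening step, with a made-up message list:

def flatten_message_text(messages: list) -> str:
    # same shape handling as above: string content is appended directly,
    # list content contributes only its "text" parts
    prompt_str = ""
    for m in messages:
        content = m["content"]
        if isinstance(content, str):
            prompt_str += content
        elif isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    prompt_str += part["text"]
    return prompt_str


messages = [
    {"role": "system", "content": "You are helpful. "},
    {"role": "user", "content": [{"type": "text", "text": "hi"}]},
]
print(flatten_message_text(messages))  # "You are helpful. hi"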
@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models")
else:
self.request = httpx.Request(
method="POST", url="https://api-inference.huggingface.co/models"
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(status_code=status_code, request=self.request)
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class HuggingfaceConfig():
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = False # by default don't return the input as part of the output
return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
@ -46,50 +61,66 @@ class HuggingfaceConfig():
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None
) -> None:
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def output_parser(generated_text: str):
def output_parser(generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
tgi_models_cache = None
conv_models_cache = None
def read_tgi_conv_models():
try:
global tgi_models_cache, conv_models_cache
@ -101,30 +132,38 @@ def read_tgi_conv_models():
tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the file path relative to the script's directory
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
tgi_models.add(line.strip())
# Cache the set for future use
tgi_models_cache = tgi_models
# If not, read the file and populate the cache
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt")
file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set()
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
conv_models.add(line.strip())
# Cache the set for future use
conv_models_cache = conv_models
conv_models_cache = conv_models
return tgi_models, conv_models
except:
return set(), set()
def get_hf_task_for_model(model):
# read text file, cast it to set
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models:
@ -134,9 +173,10 @@ def get_hf_task_for_model(model):
elif "roneneldan/TinyStories" in model:
return None
else:
return "text-generation-inference" # default to tgi
return "text-generation-inference" # default to tgi
class Huggingface(BaseLLM):
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
"content-type": "application/json",
}
if api_key and headers is None:
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = default_headers
elif headers:
headers=headers
else:
headers = headers
else:
headers = default_headers
return headers
def convert_to_model_response_object(self,
completion_response,
model_response,
task,
optional_params,
encoding,
input_text,
model):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
def convert_to_model_response_object(
self,
completion_response,
model_response,
task,
optional_params,
encoding,
input_text,
model,
):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"] # type: ignore
elif task == "text-generation-inference":
if (not isinstance(completion_response, list)
] = completion_response[
"generated_text"
] # type: ignore
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]):
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}")
or "generated_text" not in completion_response[0]
):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]:
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = []
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]):
for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob)
else:
if len(item["generated_text"]) > 0:
message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][
"content"
] = output_parser(completion_response[0]["generated_text"])
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
completion_response[0]["generated_text"]
)
## CALCULATING USAGE
prompt_tokens = 0
try:
@ -221,12 +289,14 @@ class Huggingface(BaseLLM):
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
else:
completion_tokens = 0
model_response["created"] = int(time.time())
@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
model_response._hidden_params["original_response"] = completion_response
return model_response
def completion(self,
def completion(
self,
model: str,
messages: list,
api_base: Optional[str],
@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
completion_url = f"https://api-inference.huggingface.co/models/{model}"
## Load Config
config=litellm.HuggingfaceConfig.get_config()
config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
### MAP INPUT PARAMS
@ -298,11 +371,11 @@ class Huggingface(BaseLLM):
generated_responses.append(message["content"])
data = {
"inputs": {
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses
"text": text,
"past_user_inputs": past_user_inputs,
"generated_responses": generated_responses,
},
"parameters": inference_params
"parameters": inference_params,
}
input_text = "".join(message["content"] for message in messages)
elif task == "text-generation-inference":
@ -311,29 +384,39 @@ class Huggingface(BaseLLM):
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
else:
# Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params
# We need this branch, it removes 'details' and 'return_full_text' from params
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
@ -346,52 +429,68 @@ class Huggingface(BaseLLM):
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
"stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion},
)
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"task": task,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if acompletion is True:
### ASYNC STREAMING
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
else:
### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"]
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
### SYNC COMPLETION
else:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data)
completion_url, headers=headers, data=json.dumps(data)
)
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
is_streamed = False
if (
response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True
# iterate over the complete streamed response, and return the final answer
if is_streamed:
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj)
streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = ""
for chunk in streamed_response:
for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
completion_response: List[Dict[str, Any]] = [
{"generated_text": content}
]
## LOGGING
logging_obj.post_call(
input=input_text,
@ -399,7 +498,7 @@ class Huggingface(BaseLLM):
original_response=completion_response,
additional_args={"complete_input_dict": data, "task": task},
)
else:
else:
## LOGGING
logging_obj.post_call(
input=input_text,
@ -410,15 +509,20 @@ class Huggingface(BaseLLM):
## RESPONSE OBJECT
try:
completion_response = response.json()
if isinstance(completion_response, dict):
if isinstance(completion_response, dict):
completion_response = [completion_response]
except:
import traceback
raise HuggingfaceError(
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
status_code=response.status_code,
)
print_verbose(f"response: {completion_response}")
if isinstance(completion_response, dict) and "error" in completion_response:
if (
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
@ -432,75 +536,98 @@ class Huggingface(BaseLLM):
optional_params=optional_params,
encoding=encoding,
input_text=input_text,
model=model
model=model,
)
except HuggingfaceError as e:
except HuggingfaceError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
except Exception as e:
if exception_mapping_worked:
raise e
else:
else:
import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
encoding: Any,
input_text: str,
model: str,
optional_params: dict):
response = None
try:
async def acompletion(
self,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
encoding: Any,
input_text: str,
model: str,
optional_params: dict,
):
response = None
try:
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers, timeout=None)
response = await client.post(
url=api_base, json=data, headers=headers, timeout=None
)
response_json = response.json()
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
raise HuggingfaceError(
status_code=response.status_code,
message=response.text,
request=response.request,
response=response,
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params)
except Exception as e:
if isinstance(e,httpx.TimeoutException):
return self.convert_to_model_response_object(
completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params,
)
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
async def async_streaming(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
):
async with httpx.AsyncClient() as client:
response = client.stream(
"POST",
url=f"{api_base}",
json=data,
headers=headers
)
async with response as r:
"POST", url=f"{api_base}", json=data, headers=headers
)
async with response as r:
if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
raise HuggingfaceError(
status_code=r.status_code,
message="An error occurred while streaming",
)
streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding(self,
def embedding(
self,
model: str,
input: list,
api_key: Optional[str] = None,
@ -523,65 +650,70 @@ class Huggingface(BaseLLM):
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {
"inputs": {
"source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
"source_sentence": input[0],
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
}
}
else:
data = {
"inputs": input # type: ignore
}
data = {"inputs": input} # type: ignore
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error'])
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
output_data = []
if "similarities" in embeddings:
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
}
)
else:
{
"object": "embedding",
"index": idx,
"embedding": embedding, # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
elif isinstance(embedding, list) and isinstance(embedding[0], float):
@ -589,15 +721,17 @@ class Huggingface(BaseLLM):
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
"embedding": embedding, # flatten list returned from hf
}
)
else:
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][0] # flatten list returned from hf
"embedding": embedding[0][
0
], # flatten list returned from hf
}
)
model_response["object"] = "list"
@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = {
"prompt_tokens": input_tokens,
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
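Several paths in the Huggingface handler above run generated text through output_parser, which strips a single leading or trailing chat-template token (the ChatML markers, <s>, </s>) so the returned content is clean. A self-contained sketch of that behaviour on a made-up completion string:

def strip_chat_template_tokens(generated_text: str) -> str:
    # same approach as output_parser: remove one leading and one trailing special token if present
    chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
    for token in chat_template_tokens:
        if generated_text.strip().startswith(token):
            generated_text = generated_text.replace(token, "", 1)
        if generated_text.endswith(token):
            generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
    return generated_text


print(repr(strip_chat_template_tokens("<|assistant|> Paris is the capital.</s>")))  # ' Paris is the capital.'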
@ -7,6 +7,7 @@ from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
class MaritalkError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,24 +16,26 @@ class MaritalkError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class MaritTalkConfig():
class MaritTalkConfig:
"""
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
- `model` (string): The model used for conversation. Default is 'maritalk'.
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
"""
max_tokens: Optional[int] = None
model: Optional[str] = None
do_sample: Optional[bool] = None
@ -41,27 +44,40 @@ class MaritTalkConfig():
repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None
def __init__(self,
max_tokens: Optional[int]=None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None) -> None:
def __init__(
self,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
headers = {
"accept": "application/json",
@ -71,6 +87,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Key {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -89,9 +106,11 @@ def completion(
model = model
## Load Config
config=litellm.MaritTalkConfig.get_config()
config = litellm.MaritTalkConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -101,24 +120,27 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
input=messages,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
@ -130,15 +152,17 @@ def completion(
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["answer"]
model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
except Exception as e:
raise MaritalkError(message=response.text, status_code=response.status_code)
raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages)
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -148,11 +172,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(
model: str,
input: list,
@ -161,4 +186,4 @@ def embedding(
model_response=None,
encoding=None,
):
pass
pass

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, Usage
class NLPCloudError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,7 +16,8 @@ class NLPCloudError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class NLPCloudConfig():
class NLPCloudConfig:
"""
Reference: https://docs.nlpcloud.com/#generation
@ -43,45 +45,57 @@ class NLPCloudConfig():
- `num_return_sequences` (int): Optional. The number of independently computed returned sequences.
"""
max_length: Optional[int]=None
length_no_input: Optional[bool]=None
end_sequence: Optional[str]=None
remove_end_sequence: Optional[bool]=None
remove_input: Optional[bool]=None
bad_words: Optional[list]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
num_beams: Optional[int]=None
num_return_sequences: Optional[int]=None
max_length: Optional[int] = None
length_no_input: Optional[bool] = None
end_sequence: Optional[str] = None
remove_end_sequence: Optional[bool] = None
remove_input: Optional[bool] = None
bad_words: Optional[list] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
num_beams: Optional[int] = None
num_return_sequences: Optional[int] = None
def __init__(self,
max_length: Optional[int]=None,
length_no_input: Optional[bool]=None,
end_sequence: Optional[str]=None,
remove_end_sequence: Optional[bool]=None,
remove_input: Optional[bool]=None,
bad_words: Optional[list]=None,
temperature: Optional[float]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None,
repetition_penalty: Optional[float]=None,
num_beams: Optional[int]=None,
num_return_sequences: Optional[int]=None) -> None:
def __init__(
self,
max_length: Optional[int] = None,
length_no_input: Optional[bool] = None,
end_sequence: Optional[str] = None,
remove_end_sequence: Optional[bool] = None,
remove_input: Optional[bool] = None,
bad_words: Optional[list] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
num_beams: Optional[int] = None,
num_return_sequences: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key):
@ -93,6 +107,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -110,9 +125,11 @@ def completion(
headers = validate_environment(api_key)
## Load Config
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url_fragment_1 = api_base
@ -129,24 +146,31 @@ def completion(
## LOGGING
logging_obj.pre_call(
input=text,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url},
)
input=text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return clean_and_iterate_chunks(response)
else:
## LOGGING
logging_obj.post_call(
input=text,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=text,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
@ -161,11 +185,16 @@ def completion(
else:
try:
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["generated_text"]
model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
except:
raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code)
raise NLPCloudError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = completion_response["nb_input_tokens"]
completion_tokens = completion_response["nb_generated_tokens"]
@ -174,7 +203,7 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
@ -187,25 +216,27 @@ def completion(
# # Perform further processing based on your needs
# return cleaned_chunk
# for line in response.iter_lines():
# if line:
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b''
buffer = b""
for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break
buffer += chunk
while b'\x00' in buffer:
buffer = buffer.replace(b'\x00', b'')
yield buffer.decode('utf-8')
buffer = b''
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""
# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode('utf-8')
yield buffer.decode("utf-8")
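A runnable sketch of the NUL-stripping behaviour above, driven by an in-memory stand-in for `response.iter_content`; the chunk contents are made up.

def fake_iter_content(chunk_size=1024):  # stand-in for requests' response.iter_content
    yield b"hel\x00lo"
    yield b" wor\x00ld"

buffer = b""
for chunk in fake_iter_content():
    buffer += chunk
    while b"\x00" in buffer:  # a NUL byte marks the end of a chunk worth emitting
        buffer = buffer.replace(b"\x00", b"")
        print(buffer.decode("utf-8"))  # "hello", then " world"
        buffer = b""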
def embedding():
# logic for parsing in - calling - parsing out model embedding calls

View file

@ -2,10 +2,11 @@ import requests, types, time
import json, uuid
import traceback
from typing import Optional
import litellm
import litellm
import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -16,14 +17,15 @@ class OllamaError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class OllamaConfig():
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for Ollama's API interface. Below are the parameters:
- `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0
- `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1
- `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0
@ -56,102 +58,134 @@ class OllamaConfig():
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
"""
mirostat: Optional[int]=None
mirostat_eta: Optional[float]=None
mirostat_tau: Optional[float]=None
num_ctx: Optional[int]=None
num_gqa: Optional[int]=None
num_thread: Optional[int]=None
repeat_last_n: Optional[int]=None
repeat_penalty: Optional[float]=None
temperature: Optional[float]=None
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float]=None
num_predict: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
system: Optional[str]=None
template: Optional[str]=None
def __init__(self,
mirostat: Optional[int]=None,
mirostat_eta: Optional[float]=None,
mirostat_tau: Optional[float]=None,
num_ctx: Optional[int]=None,
num_gqa: Optional[int]=None,
num_thread: Optional[int]=None,
repeat_last_n: Optional[int]=None,
repeat_penalty: Optional[float]=None,
temperature: Optional[float]=None,
stop: Optional[list]=None,
tfs_z: Optional[float]=None,
num_predict: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
system: Optional[str]=None,
template: Optional[str]=None) -> None:
mirostat: Optional[int] = None
mirostat_eta: Optional[float] = None
mirostat_tau: Optional[float] = None
num_ctx: Optional[int] = None
num_gqa: Optional[int] = None
num_thread: Optional[int] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[
list
] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
system: Optional[str] = None
template: Optional[str] = None
def __init__(
self,
mirostat: Optional[int] = None,
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
num_ctx: Optional[int] = None,
num_gqa: Optional[int] = None,
num_thread: Optional[int] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
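A small sketch of how these options end up in the `/api/generate` payload built further down (`data = {"model": model, "prompt": prompt, **optional_params}`); the option values are illustrative, not defaults.

optional_params = {"temperature": 0.1, "num_ctx": 4096, "stream": False}  # illustrative options
data = {"model": "llama2", "prompt": "Why is the sky blue?", **optional_params}
# -> {'model': 'llama2', 'prompt': 'Why is the sky blue?',
#     'temperature': 0.1, 'num_ctx': 4096, 'stream': False}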
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None
):
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None,
):
if api_base.endswith("/api/generate"):
url = api_base
else:
else:
url = f"{api_base}/api/generate"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
)
if acompletion is True:
if acompletion is True:
if optional_params.get("stream", False) == True:
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response
elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
)
url=f"{url}",
json=data,
)
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
raise OllamaError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
@ -168,52 +202,76 @@ def get_ollama_response(
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
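A sketch of the `format == "json"` branch above, using a plain dict in place of `litellm.Message`; the model output is made up.

import json, uuid

response_json = {"response": json.dumps({"city": "Paris"})}  # made-up Ollama output
message = {
    "content": None,  # content stays empty; the JSON is surfaced as a synthetic tool call
    "tool_calls": [
        {
            "id": f"call_{uuid.uuid4()}",
            "function": {"arguments": response_json["response"], "name": ""},
            "type": "function",
        }
    ],
}
print(message["tool_calls"][0]["function"]["arguments"])  # '{"city": "Paris"}'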
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url,
json=data,
method="POST",
timeout=litellm.request_timeout
) as response:
try:
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
try:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}",
json=data,
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
traceback.print_exc()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
data["stream"] = False
try:
@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
if resp.status != 200:
text = await resp.text()
raise OllamaError(status_code=resp.status, message=text)
## LOGGING
logging_obj.post_call(
input=data['prompt'],
input=data["prompt"],
api_key="",
original_response=resp.text,
additional_args={
@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": response_json["response"],
"name": "",
},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["choices"][0]["message"]["content"] = response_json[
"response"
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model']
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
except Exception as e:
traceback.print_exc()
raise e
async def ollama_aembeddings(api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None):
async def ollama_aembeddings(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None,
):
if api_base.endswith("/api/embeddings"):
url = api_base
else:
else:
url = f"{api_base}/api/embeddings"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
if response.status != 200:
text = await response.text()
raise OllamaError(status_code=response.status, message=text)
## LOGGING
logging_obj.post_call(
input=prompt,
@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = len(encoding.encode(prompt))
input_tokens = len(encoding.encode(prompt))
model_response["usage"] = {
"prompt_tokens": input_tokens,
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
return model_response
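For reference, a compact sketch of the OpenAI-style embedding response assembled above; the vectors are placeholders.

embeddings = [[0.1, 0.2], [0.3, 0.4]]  # placeholder vectors from Ollama
output_data = [
    {"object": "embedding", "index": idx, "embedding": emb}
    for idx, emb in enumerate(embeddings)
]
response = {"object": "list", "data": output_data, "model": "llama2"}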

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class OobaboogaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,6 +16,7 @@ class OobaboogaError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
def validate_environment(api_key):
headers = {
"accept": "application/json",
@ -24,6 +26,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}"
return headers
def completion(
model: str,
messages: list,
@ -44,21 +47,24 @@ def completion(
completion_url = model
elif api_base:
completion_url = api_base
else:
raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')")
else:
raise OobaboogaError(
status_code=404,
message="API Base not set. Set one via completion(..,api_base='your-api-url')",
)
model = model
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
completion_url = completion_url + "/api/v1/generate"
data = {
"prompt": prompt,
@ -66,30 +72,35 @@ def completion(
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise OobaboogaError(message=response.text, status_code=response.status_code)
raise OobaboogaError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise OobaboogaError(
message=completion_response["error"],
@ -97,14 +108,17 @@ def completion(
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text']
model_response["choices"][0]["message"][
"content"
] = completion_response["results"][0]["text"]
except:
raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code)
raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
@ -114,11 +128,12 @@ def completion(
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
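A sketch of the registry entry the `custom_prompt_dict` lookup above expects; the model key and role markers are hypothetical, only the field names mirror what the completion function reads.

custom_prompt_dict = {
    "my-oobabooga-model": {  # hypothetical model key
        "roles": {
            "system": {"pre_message": "[SYS] ", "post_message": "\n"},
            "user": {"pre_message": "[USER] ", "post_message": "\n"},
            "assistant": {"pre_message": "[BOT] ", "post_message": "\n"},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "[BOT] ",
    }
}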

File diff suppressed because it is too large

View file

@ -1,30 +1,41 @@
from typing import List, Dict
import types
class OpenrouterConfig():
class OpenrouterConfig:
"""
Reference: https://openrouter.ai/docs#format
"""
# OpenRouter-only parameters
extra_body: Dict[str, List[str]] = {
'transforms': [] # default transforms to []
}
extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
def __init__(self,
transforms: List[str] = [],
models: List[str] = [],
route: str = '',
) -> None:
def __init__(
self,
transforms: List[str] = [],
models: List[str] = [],
route: str = "",
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
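For illustration, the OpenRouter-only parameters this config carries; the transform name below is an example override, not a documented default.

default_extra_body = {"transforms": []}  # what the class ships with
custom_extra_body = {"transforms": ["middle-out"]}  # hypothetical override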

View file

@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm
import sys, httpx
class PalmError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class PalmConfig():
class PalmConfig:
"""
Reference: https://developers.generativeai.google/api/python/google/generativeai/chat
@ -37,35 +42,47 @@ class PalmConfig():
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
"""
context: Optional[str]=None
examples: Optional[list]=None
temperature: Optional[float]=None
candidate_count: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
max_output_tokens: Optional[int]=None
def __init__(self,
context: Optional[str]=None,
examples: Optional[list]=None,
temperature: Optional[float]=None,
candidate_count: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
max_output_tokens: Optional[int]=None) -> None:
context: Optional[str] = None
examples: Optional[list] = None
temperature: Optional[float] = None
candidate_count: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
max_output_tokens: Optional[int] = None
def __init__(
self,
context: Optional[str] = None,
examples: Optional[list] = None,
temperature: Optional[float] = None,
candidate_count: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
@ -83,41 +100,43 @@ def completion(
try:
import google.generativeai as palm
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
palm.configure(api_key=api_key)
model = model
## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.PalmConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.PalmConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
)
input=prompt,
api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}},
)
## COMPLETION CALL
try:
try:
response = palm.generate_text(prompt=prompt, **inference_params)
except Exception as e:
raise PalmError(
@ -127,11 +146,11 @@ def completion(
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": {}},
)
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": {}},
)
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = response
@ -142,22 +161,25 @@ def completion(
message_obj = Message(content=item["output"])
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj)
choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code)
try:
raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
raise PalmError(status_code=400, message=f"No response received. Original response - {response}")
raise PalmError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@ -165,13 +187,14 @@ def completion(
model_response["created"] = int(time.time())
model_response["model"] = "palm/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
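A runnable sketch of the prompt construction above: every branch of the role check appends the message content unchanged, so PaLM receives one flat string.

messages = [
    {"role": "system", "content": "You are terse. "},
    {"role": "user", "content": "Why is the sky blue?"},
]
prompt = ""
for message in messages:
    prompt += f"{message['content']}"  # role is checked but does not change the output
print(prompt)  # "You are terse. Why is the sky blue?"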

View file

@ -8,6 +8,7 @@ import litellm
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class PetalsError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -16,7 +17,8 @@ class PetalsError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class PetalsConfig():
class PetalsConfig:
"""
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
@ -30,45 +32,64 @@ class PetalsConfig():
- `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
- `temperature` (float, optional): This value sets the temperature for sampling.
- `top_k` (integer, optional): This value sets the limit for top-k sampling.
- `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
"""
max_length: Optional[int]=None
max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
repetition_penalty: Optional[float]=None
def __init__(self,
max_length: Optional[int]=None,
max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool]=None,
temperature: Optional[float]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
repetition_penalty: Optional[float]=None) -> None:
max_length: Optional[int] = None
max_new_tokens: Optional[
int
] = litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion(
model: str,
messages: list,
api_base: Optional[str],
api_base: Optional[str],
model_response: ModelResponse,
print_verbose: Callable,
encoding,
@ -80,96 +101,97 @@ def completion(
):
## Load Config
config = litellm.PetalsConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
if api_base:
if api_base:
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": optional_params, "api_base": api_base},
)
data = {
"model": model,
"inputs": prompt,
**optional_params
}
input=prompt,
api_key="",
additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
},
)
data = {"model": model, "inputs": prompt, **optional_params}
## COMPLETION CALL
response = requests.post(api_base, data=data)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": optional_params},
)
input=prompt,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": optional_params},
)
## RESPONSE OBJECT
try:
output_text = response.json()["outputs"]
except Exception as e:
PetalsError(status_code=response.status_code, message=str(e))
else:
else:
try:
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM # type: ignore
from petals import AutoDistributedModelForCausalLM # type: ignore
except:
raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
)
model = model
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False)
tokenizer = AutoTokenizer.from_pretrained(
model, use_fast=False, add_bos_token=False
)
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": optional_params},
)
input=prompt,
api_key="",
additional_args={"complete_input_dict": optional_params},
)
## COMPLETION CALL
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
# optional params: max_new_tokens=1,temperature=0.9, top_p=0.6
outputs = model_obj.generate(inputs, **optional_params)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=outputs,
additional_args={"complete_input_dict": optional_params},
)
input=prompt,
api_key="",
original_response=outputs,
additional_args={"complete_input_dict": optional_params},
)
## RESPONSE OBJECT
output_text = tokenizer.decode(outputs[0])
if len(output_text) > 0:
model_response["choices"][0]["message"]["content"] = output_text
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
@ -177,13 +199,14 @@ def completion(
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
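A sketch of the local (no `api_base`) path above; it assumes `petals` and `transformers` are installed, and the checkpoint name is illustrative.

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # type: ignore

model_name = "petals-team/StableBeluga2"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello, world", return_tensors="pt")["input_ids"]
outputs = model_obj.generate(inputs, max_new_tokens=32)  # petals requires max tokens to be set
print(tokenizer.decode(outputs[0]))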

View file

@ -4,11 +4,13 @@ import json
from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any
def default_pt(messages):
return " ".join(message["content"] for message in messages)
# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
prompt = custom_prompt(
role_dict={
"system": {
@ -19,59 +21,56 @@ def alpaca_pt(messages):
"pre_message": "### Instruction:\n",
"post_message": "\n\n",
},
"assistant": {
"pre_message": "### Response:\n",
"post_message": "\n\n"
}
"assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
},
bos_token="<s>",
eos_token="</s>",
messages=messages
messages=messages,
)
return prompt
# Llama2 prompt template
def llama_2_chat_pt(messages):
prompt = custom_prompt(
role_dict={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n"
"post_message": "\n<</SYS>>\n [/INST]\n",
},
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"pre_message": "[INST] ",
"post_message": " [/INST]\n"
},
"post_message": " [/INST]\n",
},
"assistant": {
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
}
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
},
},
messages=messages,
bos_token="<s>",
eos_token="</s>"
eos_token="</s>",
)
return prompt
def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
def ollama_pt(
model, messages
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt(
role_dict={
"system": {
"pre_message": "### System:\n",
"post_message": "\n"
},
"system": {"pre_message": "### System:\n", "post_message": "\n"},
"user": {
"pre_message": "### User:\n",
"post_message": "\n",
},
},
"assistant": {
"pre_message": "### Response:\n",
"post_message": "\n",
}
},
},
final_prompt_value="### Response:",
messages=messages
messages=messages,
)
elif "llava" in model:
prompt = ""
@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
return {
"prompt": prompt,
"images": images
}
else:
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
return {"prompt": prompt, "images": images}
else:
prompt = "".join(
m["content"]
if isinstance(m["content"], str) is str
else "".join(m["content"])
for m in messages
)
return prompt
def mistral_instruct_pt(messages):
def mistral_instruct_pt(messages):
prompt = custom_prompt(
initial_prompt_value="<s>",
role_dict={
"system": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"user": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"assistant": {
"pre_message": "[INST]",
"post_message": "[/INST]"
}
"system": {"pre_message": "[INST]", "post_message": "[/INST]"},
"user": {"pre_message": "[INST]", "post_message": "[/INST]"},
"assistant": {"pre_message": "[INST]", "post_message": "[/INST]"},
},
final_prompt_value="</s>",
messages=messages
messages=messages,
)
return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages):
prompt = ""
@ -125,11 +119,16 @@ def falcon_instruct_pt(messages):
if message["role"] == "system":
prompt += message["content"]
else:
prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
prompt += (
message["role"]
+ ":"
+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
)
prompt += "\n\n"
return prompt
def falcon_chat_pt(messages):
prompt = ""
for message in messages:
@ -142,6 +141,7 @@ def falcon_chat_pt(messages):
return prompt
# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages):
prompt = ""
@ -154,18 +154,20 @@ def mpt_chat_pt(messages):
prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
return prompt
# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages):
prompt = ""
for message in messages:
if message["role"] == "system":
prompt += message["content"] + "\n\n"
elif message["role"] == "user": # map to 'Instruction'
elif message["role"] == "user": # map to 'Instruction'
prompt += "### Instruction:\n" + message["content"] + "\n\n"
elif message["role"] == "assistant": # map to 'Response'
elif message["role"] == "assistant": # map to 'Response'
prompt += "### Response:\n" + message["content"] + "\n\n"
return prompt
# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages):
prompt = ""
@ -178,13 +180,17 @@ def phind_codellama_pt(messages):
prompt += "### Assistant\n" + message["content"] + "\n\n"
return prompt
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
## get the tokenizer config from huggingface
bos_token = ""
eos_token = ""
if chat_template is None:
if chat_template is None:
def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
url = (
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
)
# Make a GET request to fetch the JSON data
response = requests.get(url)
if response.status_code == 200:
@ -193,10 +199,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
return {"status": "success", "tokenizer": tokenizer_config}
else:
return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]:
if (
tokenizer_config["status"] == "failure"
or "chat_template" not in tokenizer_config["tokenizer"]
):
raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json
## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"]
@ -204,10 +214,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def raise_exception(message):
raise Exception(f"Error message - {message}")
# Create a template object from the template text
env = Environment()
env.globals['raise_exception'] = raise_exception
env.globals["raise_exception"] = raise_exception
try:
template = env.from_string(chat_template)
except Exception as e:
@ -216,137 +226,167 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def _is_system_in_template():
try:
# Try rendering the template with a system message
response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>")
response = template.render(
messages=[{"role": "system", "content": "test"}],
eos_token="<eos>",
bos_token="<bos>",
)
return True
# This will be raised if Jinja attempts to render the system message and it can't
except:
return False
try:
try:
# Render the template with the provided values
if _is_system_in_template():
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages)
else:
if _is_system_in_template():
rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=messages
)
else:
# treat a system message as a user message, if system not in template
try:
reformatted_messages = []
for message in messages:
if message["role"] == "system":
reformatted_messages.append({"role": "user", "content": message["content"]})
for message in messages:
if message["role"] == "system":
reformatted_messages.append(
{"role": "user", "content": message["content"]}
)
else:
reformatted_messages.append(message)
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages)
rendered_text = template.render(
bos_token=bos_token,
eos_token=eos_token,
messages=reformatted_messages,
)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e):
if "Conversation roles must alternate user/assistant" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
for i in range(len(reformatted_messages)-1):
for i in range(len(reformatted_messages) - 1):
new_messages.append(reformatted_messages[i])
if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]:
if (
reformatted_messages[i]["role"]
== reformatted_messages[i + 1]["role"]
):
if reformatted_messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
new_messages.append(
{"role": "assistant", "content": ""}
)
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(reformatted_messages[-1])
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=new_messages
)
return rendered_text
except Exception as e:
except Exception as e:
raise Exception(f"Error rendering template - {str(e)}")
# Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
# Anthropic template
def claude_2_1_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
"""
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
- you can't just pass a system message
- you can't pass a system message and follow that with an assistant message
- you can't pass a system message and follow that with an assistant message
if system message is passed in, you can only do system, human, assistant or system, human
if a system message is passed in and followed by an assistant message, insert a blank human message between them.
if a system message is passed in and followed by an assistant message, insert a blank human message between them.
"""
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = ""
for idx, message in enumerate(messages):
prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
elif message["role"] == "system":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
elif message["role"] == "assistant":
if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
return prompt
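What the function above produces for a system + user conversation, reconstructed from the constants and branches shown (editor's example):

HUMAN = "\n\nHuman: "
AI = "\n\nAssistant: "
messages = [
    {"role": "system", "content": "You are a poet."},
    {"role": "user", "content": "Write one line about rain."},
]
# system content is emitted bare, user content gets the Human marker, and the
# prompt always ends with the Assistant marker so the model can start its turn
expected = "You are a poet." + HUMAN + "Write one line about rain." + AI
print(repr(expected))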
### TOGETHER AI
### TOGETHER AI
def get_model_info(token, model):
try:
headers = {
'Authorization': f'Bearer {token}'
}
response = requests.get('https://api.together.xyz/models/info', headers=headers)
try:
headers = {"Authorization": f"Bearer {token}"}
response = requests.get("https://api.together.xyz/models/info", headers=headers)
if response.status_code == 200:
model_info = response.json()
for m in model_info:
if m["name"].lower().strip() == model.strip():
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
for m in model_info:
if m["name"].lower().strip() == model.strip():
return m["config"].get("prompt_format", None), m["config"].get(
"chat_template", None
)
return None, None
else:
return None, None
except Exception as e: # safely fail a prompt template request
except Exception as e: # safely fail a prompt template request
return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None:
return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None:
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template)
elif prompt_format is not None:
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt)
else:
prompt = hf_chat_template(
model=None, messages=messages, chat_template=chat_template
)
elif prompt_format is not None:
prompt = custom_prompt(
role_dict={},
messages=messages,
initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt,
)
else:
prompt = default_pt(messages)
return prompt
return prompt
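A small sketch of the `prompt_format` handling above; the format string is hypothetical, the real one comes back from Together AI's `/models/info` endpoint.

prompt_format = "<human>: {prompt}\n<bot>:"  # hypothetical format string
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
# human_prompt     == "<human>: "  -> passed as initial_prompt_value
# assistant_prompt == "\n<bot>:"   -> passed as final_prompt_value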
###
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
def anthropic_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
prompt = ""
for idx, message in enumerate(
messages
): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
if message["role"] == "user":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
elif message["role"] == "system":
prompt += (
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
)
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
else:
prompt += (
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: `
prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
if (
idx == 0 and message["role"] == "assistant"
): # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}"
return prompt
return prompt
def gemini_text_image_pt(messages: list):
def gemini_text_image_pt(messages: list):
"""
{
"contents":[
@@ -367,13 +407,15 @@ def gemini_text_image_pt(messages: list):
try:
import google.generativeai as genai
except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai'")
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
prompt = ""
images = []
for message in messages:
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
@@ -383,45 +425,63 @@ def gemini_text_image_pt(messages: list):
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
images.append(image_url)
content = [prompt] + images
return content
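# A small sketch of the OpenAI-vision-style input that gemini_text_image_pt() flattens
# into `[prompt] + images`; the image URL is a placeholder, not a real asset.
vision_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
# gemini_text_image_pt(vision_messages) would return roughly
# ["Describe this picture.", "https://example.com/cat.png"]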
# Function call template
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
for function in functions:
function_prompt = (
"Produce JSON OUTPUT ONLY! The following functions are available to you:"
)
for function in functions:
function_prompt += f"""\n{function}\n"""
function_added_to_prompt = False
for message in messages:
if "system" in message["role"]:
message['content'] += f"""{function_prompt}"""
for message in messages:
if "system" in message["role"]:
message["content"] += f"""{function_prompt}"""
function_added_to_prompt = True
if function_added_to_prompt == False:
messages.append({'role': 'system', 'content': f"""{function_prompt}"""})
if function_added_to_prompt == False:
messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages
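# A short sketch of function_call_prompt(); the function schema is a placeholder, and the
# import path assumes the factory module lives at litellm.llms.prompt_templates.factory.
from litellm.llms.prompt_templates.factory import function_call_prompt

functions = [
    {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }
]
messages = [{"role": "user", "content": "Weather in Paris?"}]
messages = function_call_prompt(messages=messages, functions=functions)
# since no system message existed, a new one is appended containing
# "Produce JSON OUTPUT ONLY! ..." followed by the function schema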
# Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""):
def custom_prompt(
role_dict: dict,
messages: list,
initial_prompt_value: str = "",
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
prompt = bos_token + initial_prompt_value
bos_open = True
## a bos token is at the start of a system / human message
## an eos token is at the end of the assistant response to the message
for message in messages:
role = message["role"]
if role in ["system", "human"] and not bos_open:
prompt += bos_token
bos_open = True
pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else ""
post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else ""
pre_message_str = (
role_dict[role]["pre_message"]
if role in role_dict and "pre_message" in role_dict[role]
else ""
)
post_message_str = (
role_dict[role]["post_message"]
if role in role_dict and "post_message" in role_dict[role]
else ""
)
prompt += pre_message_str + message["content"] + post_message_str
if role == "assistant":
prompt += eos_token
bos_open = False
@@ -429,25 +489,35 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += final_prompt_value
return prompt
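# A minimal sketch of custom_prompt() with a hypothetical Llama-2-style role_dict; these
# markers are not tied to any registered model, they only show how the pre/post message
# strings, bos_token and eos_token are stitched together.
from litellm.llms.prompt_templates.factory import custom_prompt

example = custom_prompt(
    role_dict={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n"},
        "user": {"pre_message": "", "post_message": " [/INST]\n"},
        "assistant": {"pre_message": "", "post_message": "\n"},
    },
    messages=[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hi"},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
# example is roughly "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\nHi [/INST]\n"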
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
def prompt_factory(
model: str,
messages: list,
custom_llm_provider: Optional[str] = None,
api_key: Optional[str] = None,
):
original_model_name = model
model = model.lower()
if custom_llm_provider == "ollama":
if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic":
if "claude-2.1" in model:
if "claude-2.1" in model:
return claude_2_1_pt(messages=messages)
else:
else:
return anthropic_pt(messages=messages)
elif custom_llm_provider == "together_ai":
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
elif custom_llm_provider == "gemini":
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
return gemini_text_image_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
elif (
"tiiuae/falcon" in model
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if model == "tiiuae/falcon-180B-chat":
return falcon_chat_pt(messages=messages)
elif "instruct" in model:
@@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return mpt_chat_pt(messages=messages)
elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions
return llama_2_chat_pt(
messages=messages
) # https://huggingface.co/blog/codellama#conversational-instructions
elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages)
elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
elif "togethercomputer/llama-2" in model and (
"instruct" in model or "chat" in model
):
return llama_2_chat_pt(messages=messages)
elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]:
return alpaca_pt(messages=messages)
else:
elif model in [
"gryphe/mythomax-l2-13b",
"gryphe/mythomix-l2-13b",
"gryphe/mythologic-l2-13b",
]:
return alpaca_pt(messages=messages)
else:
return hf_chat_template(original_model_name, messages)
except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
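# A rough sketch of how callers use prompt_factory(); the model names are examples, and
# the together_ai branch additionally needs an api_key because the prompt format is
# fetched from the Together AI models/info endpoint via get_model_info() above.
from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Hello"}]
anthropic_prompt = prompt_factory(
    model="claude-2", messages=messages, custom_llm_provider="anthropic"
)
llama_prompt = prompt_factory(model="meta-llama/llama-2-7b-chat-hf", messages=messages)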
View file
@@ -4,81 +4,100 @@ import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
import litellm
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class ReplicateError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments")
self.request = httpx.Request(
method="POST", url="https://api.replicate.com/v1/deployments"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class ReplicateConfig():
class ReplicateConfig:
"""
Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model.
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
- `debug` (boolean): If set to `True`, it provides debugging output in logs.
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
"""
system_prompt: Optional[str]=None
max_new_tokens: Optional[int]=None
min_new_tokens: Optional[int]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
stop_sequences: Optional[str]=None
seed: Optional[int]=None
debug: Optional[bool]=None
def __init__(self,
system_prompt: Optional[str]=None,
max_new_tokens: Optional[int]=None,
min_new_tokens: Optional[int]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,
top_k: Optional[int]=None,
stop_sequences: Optional[str]=None,
seed: Optional[int]=None,
debug: Optional[bool]=None) -> None:
system_prompt: Optional[str] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop_sequences: Optional[str] = None
seed: Optional[int] = None
debug: Optional[bool] = None
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
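# A brief sketch of how this config class is used: attributes set on
# litellm.ReplicateConfig act as provider-wide defaults, and per-call kwargs still win
# (see the "completion(top_k=3) > replicate_config(top_k=3)" note in completion() below).
# The values here are arbitrary examples.
import litellm

litellm.ReplicateConfig(max_new_tokens=256, temperature=0.7)
print(litellm.ReplicateConfig.get_config())
# -> {'max_new_tokens': 256, 'temperature': 0.7}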
# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
def start_prediction(
version_id, input_data, api_token, api_base, logging_obj, print_verbose
):
base_url = api_base
if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment")
@@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
initial_prediction_data = {
@@ -98,24 +117,33 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url},
input=input_data["prompt"],
api_key="",
additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
)
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
response = requests.post(
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
)
if response.status_code == 201:
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
status = ""
@@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
if response.status_code == 200:
response_data = response.json()
if "output" in response_data:
output_string = "".join(response_data['output'])
output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get('status', None)
status = response_data.get("status", None)
logs = response_data.get("logs", "")
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = ""
@@ -146,30 +178,34 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
status = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
time.sleep(0.5) # prevent being rate limited by replicate
time.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = requests.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data['status']
status = response_data["status"]
if "output" in response_data:
output_string = "".join(response_data['output'])
new_output = output_string[len(previous_output):]
output_string = "".join(response_data["output"])
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data['status']
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to extract version ID from model string
def model_to_version_id(model):
@@ -178,11 +214,12 @@ def model_to_version_id(model):
return split_model[1]
return model
# Main function for prediction completion
def completion(
model: str,
messages: list,
api_base: str,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
@@ -196,35 +233,37 @@ def completion(
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
## Load Config
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.ReplicateConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt")
else:
supports_sys_prompt = False
if supports_sys_prompt:
for i in range(len(messages)):
if messages[i]["role"] == "system":
first_sys_message = messages.pop(i)
system_prompt = first_sys_message["content"]
break
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
@@ -233,43 +272,58 @@ def completion(
input_data = {
"prompt": prompt,
"system_prompt": system_prompt,
**optional_params
**optional_params,
}
# Otherwise, use the prompt as is
else:
input_data = {
"prompt": prompt,
**optional_params
}
input_data = {"prompt": prompt, **optional_params}
## COMPLETION CALL
## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response
## Step2: is handled with and without streaming
model_response["created"] = int(time.time()) # for pricing this must remain right before calling api
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
model_response["created"] = int(
time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
)
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request")
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
return handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
else:
result, logs = handle_prediction_response(prediction_url, api_key, print_verbose)
model_response["ended"] = time.time() # for pricing this must remain right after calling api
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
)
model_response[
"ended"
] = time.time() # for pricing this must remain right after calling api
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=result,
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
input=prompt,
api_key="",
original_response=result,
additional_args={
"complete_input_dict": input_data,
"logs": logs,
"api_base": prediction_url,
},
)
print_verbose(f"raw model_response: {result}")
if len(result) == 0: # edge case, where result from replicate is empty
if len(result) == 0: # edge case, where result from replicate is empty
result = " "
## Building RESPONSE OBJECT
@@ -278,12 +332,14 @@ def completion(
# Calculate usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", "")))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["model"] = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
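# A hedged end-to-end sketch of the Replicate flow implemented above, driven through the
# public litellm.completion() entrypoint; the model string is a placeholder (Replicate
# identifiers usually include an owner and a version hash) and the API key is assumed to
# be read from REPLICATE_API_KEY.
import os
import litellm

os.environ["REPLICATE_API_KEY"] = "<your-replicate-api-key>"
response = litellm.completion(
    model="replicate/<model-owner>/<model-name>:<version-id>",
    messages=[{"role": "user", "content": "Hello from Replicate"}],
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])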
View file
@@ -11,42 +11,61 @@ from copy import deepcopy
import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt
class SagemakerError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker")
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class SagemakerConfig():
class SagemakerConfig:
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
"""
max_new_tokens: Optional[int]=None
top_p: Optional[float]=None
temperature: Optional[float]=None
return_full_text: Optional[bool]=None
def __init__(self,
max_new_tokens: Optional[int]=None,
top_p: Optional[float]=None,
temperature: Optional[float]=None,
return_full_text: Optional[bool]=None) -> None:
max_new_tokens: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
return_full_text: Optional[bool] = None
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
"""
SAGEMAKER AUTH Keys/Vars
os.environ['AWS_ACCESS_KEY_ID'] = ""
@@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name>
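# A hedged sketch of calling a SageMaker endpoint through litellm.completion() with the
# environment-variable auth described above; the endpoint name is a placeholder, and
# hf_model_name is the optional hint used below to pick a prompt template.
import os
import litellm

os.environ["AWS_ACCESS_KEY_ID"] = "<access-key-id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret-access-key>"
os.environ["AWS_REGION_NAME"] = "us-west-2"

response = litellm.completion(
    model="sagemaker/<your-endpoint-name>",
    messages=[{"role": "user", "content": "Hello from SageMaker"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",  # apply the Llama-2 chat template
)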
def completion(
model: str,
messages: list,
@@ -85,28 +105,30 @@ def completion(
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
inference_params.pop("stream", None)
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
model = model
@@ -114,25 +136,26 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
else:
if hf_model_name is None:
if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template
if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template
else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = hf_model_name or model # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
data = json.dumps({
"inputs": prompt,
"parameters": inference_params
}).encode('utf-8')
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8"
)
## LOGGING
request_str = f"""
@@ -142,31 +165,35 @@ def completion(
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name},
)
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
## COMPLETION CALL
try:
try:
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"].read().decode("utf8")
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = json.loads(response)
@@ -177,19 +204,20 @@ def completion(
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
encoding.encode(prompt)
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
@@ -197,28 +225,32 @@ def completion(
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding(model: str,
input: list,
model_response: EmbeddingResponse,
print_verbose: Callable,
encoding,
logging_obj,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None):
def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
print_verbose: Callable,
encoding,
logging_obj,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
"""
Supports Huggingface Jumpstart embeddings like GPT-6B
Supports Huggingface Jumpstart embeddings like GPT-6B
"""
### BOTO3 INIT
import boto3
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@@ -234,34 +266,34 @@ def embedding(model: str,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
inference_params.pop("stream", None)
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
#### HF EMBEDDING LOGIC
data = json.dumps({
"text_inputs": input
}).encode('utf-8')
#### HF EMBEDDING LOGIC
data = json.dumps({"text_inputs": input}).encode("utf-8")
## LOGGING
request_str = f"""
@@ -270,67 +302,65 @@ def embedding(model: str,
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)""" # type: ignore
)""" # type: ignore
logging_obj.pre_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
input=input,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
## EMBEDDING CALL
try:
try:
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=response,
)
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=response,
)
response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
input=input,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
if "embedding" not in response:
if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response")
embeddings = response['embedding']
embeddings = response["embedding"]
if not isinstance(embeddings, list):
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
raise SagemakerError(
status_code=422, message=f"Response not in expected format - {embeddings}"
)
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
return model_response
View file
@@ -9,17 +9,21 @@ import httpx
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference")
self.request = httpx.Request(
method="POST", url="https://api.together.xyz/inference"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TogetherAIConfig():
class TogetherAIConfig:
"""
Reference: https://docs.together.ai/reference/inference
@@ -37,35 +41,49 @@ class TogetherAIConfig():
- `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition.
- `logprobs` (int32, optional): This parameter is not described in the prompt.
- `logprobs` (int32, optional): This parameter is not described in the prompt.
"""
max_tokens: Optional[int]=None
stop: Optional[str]=None
temperature:Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
logprobs: Optional[int]=None
def __init__(self,
max_tokens: Optional[int]=None,
stop: Optional[str]=None,
temperature:Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None,
repetition_penalty: Optional[float]=None,
logprobs: Optional[int]=None) -> None:
max_tokens: Optional[int] = None
stop: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
logprobs: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stop: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
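# A short sketch of the precedence noted in completion() below
# ("completion(top_k=3) > togetherai_config(top_k=3)"): values set on
# litellm.TogetherAIConfig are defaults that are only copied into optional_params when
# the caller did not pass that key. The values here are arbitrary examples.
import litellm

litellm.TogetherAIConfig(max_tokens=256, temperature=0.2)
# a later completion(..., temperature=0.9) keeps temperature=0.9
# and falls back to max_tokens=256 from the config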
def validate_environment(api_key):
@@ -80,10 +98,11 @@ def validate_environment(api_key):
}
return headers
def completion(
model: str,
messages: list,
api_base: str,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
@@ -97,9 +116,11 @@ def completion(
headers = validate_environment(api_key)
## Load Config
config = litellm.TogetherAIConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
config = litellm.TogetherAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
@@ -107,15 +128,20 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list
prompt = prompt_factory(
model=model,
messages=messages,
api_key=api_key,
custom_llm_provider="together_ai",
) # api key required to query together ai model list
data = {
"model": model,
@@ -128,13 +154,14 @@ def completion(
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base},
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
response = requests.post(
api_base,
headers=headers,
@@ -143,18 +170,14 @@ def completion(
)
return response.iter_lines()
else:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data)
)
response = requests.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:
@@ -170,30 +193,38 @@ def completion(
)
elif "error" in completion_response["output"]:
raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code
message=json.dumps(completion_response["output"]),
status_code=response.status_code,
)
if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]
model_response["choices"][0]["message"]["content"] = completion_response[
"output"
]["choices"][0]["text"]
## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
print_verbose(
f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"]
model_response.choices[0].finish_reason = completion_response["output"][
"choices"
][0]["finish_reason"]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
View file
@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
import httpx
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/")
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIConfig():
class VertexAIConfig:
"""
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@@ -34,28 +38,42 @@ class VertexAIConfig():
Note: Please make sure to modify the default parameters as required for your use case.
"""
temperature: Optional[float]=None
max_output_tokens: Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
def __init__(self,
temperature: Optional[float]=None,
max_output_tokens: Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None) -> None:
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
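# A hedged sketch of the Vertex AI path configured above; vertexai authenticates with
# application-default credentials rather than an API key, and the project id here is a
# placeholder. "gemini-pro" is assumed to be in litellm.vertex_language_models, so it is
# routed through the GenerativeModel branch further down.
import litellm

litellm.vertex_project = "<gcp-project-id>"
litellm.vertex_location = "us-central1"
litellm.VertexAIConfig(max_output_tokens=256, temperature=0.3)

response = litellm.completion(
    model="gemini-pro",
    messages=[{"role": "user", "content": "Hello from Vertex AI"}],
)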
def _get_image_bytes_from_url(image_url: str) -> bytes:
try:
@@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
return image_bytes
except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout)
return b'' # Return an empty bytes object or handle the error as needed
return b"" # Return an empty bytes object or handle the error as needed
def _load_image_from_url(image_url: str):
@@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str):
Returns:
Image: The loaded image.
"""
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes)
def _gemini_vision_convert_messages(
messages: list
):
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
@@ -95,7 +118,7 @@ def _gemini_vision_convert_messages(
Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function.
@@ -115,11 +138,23 @@ def _gemini_vision_convert_messages(
try:
import vertexai
except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
@@ -159,6 +194,7 @@ def _gemini_vision_convert_messages(
except Exception as e:
raise e
def completion(
model: str,
messages: list,
@@ -171,30 +207,38 @@ def completion(
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool=False
acompletion: bool = False,
):
try:
import vertexai
except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init(
project=vertex_project, location=vertex_location
)
vertexai.init(project=vertex_project, location=vertex_location)
## Load Config
config = litellm.VertexAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## Process safety settings into format expected by vertex AI
## Process safety settings into format expected by vertex AI
safety_settings = None
if "safety_settings" in optional_params:
safety_settings = optional_params.pop("safety_settings")
@@ -202,17 +246,25 @@ def completion(
raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]
safety_settings = [
gapic_content_types.SafetySetting(x) for x in safety_settings
]
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
prompt = " ".join(
[
message["content"]
for message in messages
if isinstance(message["content"], str)
]
)
mode = ""
mode = ""
request_str = ""
response_obj = None
if model in litellm.vertex_language_models:
if model in litellm.vertex_language_models:
llm_model = GenerativeModel(model)
mode = ""
request_str += f"llm_model = GenerativeModel({model})\n"
@@ -232,31 +284,76 @@ def completion(
llm_model = CodeGenerationModel.from_pretrained(model)
mode = "text"
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
else: # vertex_code_llm_models
else: # vertex_code_llm_models
llm_model = CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
if acompletion == True: # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True:
if acompletion == True: # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True:
# async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
return async_streaming(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
return async_completion(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
encoding=encoding,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
if mode == "":
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
@@ -268,21 +365,35 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
contents=content,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=True
stream=True,
)
optional_params["stream"] = True
return model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
response = llm_model.generate_content(
contents=content,
@@ -293,88 +404,150 @@ def completion(
response_obj = response._raw_response
elif mode == "chat":
chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n"
request_str += f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True
return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = llm_model.predict(prompt, **optional_params).text
## LOGGING
logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
else:
prompt_tokens = len(
encoding.encode(prompt)
)
model_response["choices"][0].finish_reason = response_obj.candidates[
0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
async def async_completion(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
encoding=None,
messages=None,
print_verbose=None,
**optional_params,
):
"""
Add support for acompletion calls for gemini-pro
"""
try:
try:
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text
response_obj = response_obj._raw_response
elif mode == "vision":
@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call
response = await llm_model._generate_content_async(
contents=content,
generation_config=GenerationConfig(**optional_params)
contents=content, generation_config=GenerationConfig(**optional_params)
)
completion_response = response.text
response_obj = response._raw_response
@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# chat-bison etc.
chat = llm_model.start_chat()
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text
elif mode == "text":
# gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text
@ -416,51 +609,77 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
)
## RESPONSE OBJECT
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][
"content"
] = str(completion_response)
if len(str(completion_response)) > 0:
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count)
model_response["choices"][0].finish_reason = response_obj.candidates[
0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count,
)
else:
prompt_tokens = len(
encoding.encode(prompt)
)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
async def async_streaming(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
messages=None,
print_verbose=None,
**optional_params,
):
"""
Add support for async streaming calls for gemini-pro
"""
from vertexai.preview.generative_models import GenerationConfig
if mode == "":
if mode == "":
# gemini-pro
chat = llm_model.start_chat()
stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True
elif mode == "vision":
elif mode == "vision":
stream = optional_params.pop("stream")
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r
content = [prompt] + images
stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model._generate_content_streaming_async(
contents=content,
generation_config=GenerationConfig(**optional_params),
stream=True
stream=True,
)
optional_params["stream"] = True
elif mode == "chat":
chat = llm_model.start_chat()
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True
elif mode == "text":
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
optional_params.pop(
"stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model.predict_streaming_async(prompt, **optional_params)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
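
For context, a minimal sketch of exercising the async Gemini Pro paths above through litellm's public acompletion API. The "vertex_ai/gemini-pro" model string is an assumption, and Vertex AI credentials/project are expected to already be configured in the environment.

# Sketch only: a non-streaming call exercises async_completion (mode == ""),
# while stream=True goes through async_streaming + CustomStreamWrapper.
import asyncio

import litellm


async def demo():
    resp = await litellm.acompletion(
        model="vertex_ai/gemini-pro",  # model string assumed
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=64,
    )
    print(resp["choices"][0]["message"]["content"])

    stream = await litellm.acompletion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


if __name__ == "__main__":
    asyncio.run(demo())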


@ -6,7 +6,10 @@ import time, httpx
from typing import Callable, Any
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
llm = None
class VLLMError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -17,17 +20,20 @@ class VLLMError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
# check if vllm is installed
def validate_environment(model: str):
global llm
try:
from vllm import LLM, SamplingParams # type: ignore
try:
from vllm import LLM, SamplingParams # type: ignore
if llm is None:
llm = LLM(model=model)
return llm, SamplingParams
except Exception as e:
raise VLLMError(status_code=0, message=str(e))
def completion(
model: str,
messages: list,
@ -50,15 +56,14 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -69,9 +74,10 @@ def completion(
if llm:
outputs = llm.generate(prompt, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
return iter(outputs)
@ -88,24 +94,22 @@ def completion(
model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text
## CALCULATING USAGE
prompt_tokens = len(outputs[0].prompt_token_ids)
completion_tokens = len(outputs[0].outputs[0].token_ids)
prompt_tokens = len(outputs[0].prompt_token_ids)
completion_tokens = len(outputs[0].outputs[0].token_ids)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def batch_completions(
model: str,
messages: list,
optional_params=None,
custom_prompt_dict={}
model: str, messages: list, optional_params=None, custom_prompt_dict={}
):
"""
Example usage:
@ -137,31 +141,33 @@ def batch_completions(
except Exception as e:
error_str = str(e)
if "data parallel group is already initialized" in error_str:
pass
pass
else:
raise VLLMError(status_code=0, message=error_str)
sampling_params = SamplingParams(**optional_params)
prompts = []
prompts = []
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
for message in messages:
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message,
)
prompts.append(prompt)
else:
for message in messages:
prompt = prompt_factory(model=model, messages=message)
prompts.append(prompt)
if llm:
outputs = llm.generate(prompts, sampling_params)
else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
final_outputs = []
for output in outputs:
@ -170,20 +176,21 @@ def batch_completions(
model_response["choices"][0]["message"]["content"] = output.outputs[0].text
## CALCULATING USAGE
prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids)
prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
final_outputs.append(model_response)
return final_outputs
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
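
A hedged sketch of driving the batch path above. The litellm.llms.vllm module path and the model name are assumptions, and running it requires vllm to be installed (validate_environment imports it lazily) plus enough resources to load the model.

# Sketch: batch_completions builds one prompt per conversation, generates them
# through a single shared vllm.LLM instance, and returns one ModelResponse per
# prompt with token-count based usage.
from litellm.llms import vllm as vllm_handler  # module path assumed

conversations = [
    [{"role": "user", "content": "Write a haiku about the sea."}],
    [{"role": "user", "content": "Summarize what vLLM does in one line."}],
]

responses = vllm_handler.batch_completions(
    model="facebook/opt-125m",  # illustrative; any model vllm can load
    messages=conversations,
    optional_params={"max_tokens": 64, "temperature": 0.2},
)

for r in responses:
    print(r["model"], r["choices"][0]["message"]["content"])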

File diff suppressed because it is too large


@ -1 +1 @@
from . import *
from . import *


@ -1,4 +1,4 @@
def my_custom_rule(input): # receives the model response
def my_custom_rule(input): # receives the model response
# if len(input) < 5: # trigger fallback if the model response is too short
return False
return True
return False
return True
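
As a hedged sketch of how a rule like this might be wired up outside the proxy: the post_call_rules attribute is assumed from litellm's rules feature (it is not shown in this diff), and the rule below uses the length check from the comment above instead of always returning False.

# Sketch: a post-call rule that rejects responses shorter than 5 characters.
# If a registered rule returns False, litellm treats the call as failed.
import litellm


def min_length_rule(input):  # receives the model response text
    return len(input) >= 5


litellm.post_call_rules = [min_length_rule]  # attribute assumed, see lead-in

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "say hi"}],
)
print(response)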


@ -3,13 +3,15 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
try:
return self.model_dump() # noqa
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
@ -34,7 +36,7 @@ class ProxyChatCompletionRequest(LiteLLMBase):
tools: Optional[List[str]] = None
tool_choice: Optional[str] = None
functions: Optional[List[str]] = None # soon to be deprecated
function_call: Optional[str] = None # soon to be deprecated
function_call: Optional[str] = None # soon to be deprecated
# Optional LiteLLM params
caching: Optional[bool] = None
@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
request_timeout: Optional[int] = None
class Config:
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase):
id: Optional[str]
@ -57,38 +60,37 @@ class ModelInfoDelete(LiteLLMBase):
class ModelInfo(LiteLLMBase):
id: Optional[str]
mode: Optional[Literal['embedding', 'chat', 'completion']]
mode: Optional[Literal["embedding", "chat", "completion"]]
input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] = 2048 # assume 2048 if not set
max_tokens: Optional[int] = 2048 # assume 2048 if not set
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json
base_model: Optional[Literal
[
'gpt-4-1106-preview',
'gpt-4-32k',
'gpt-4',
'gpt-3.5-turbo-16k',
'gpt-3.5-turbo',
'text-embedding-ada-002',
]
]
base_model: Optional[
Literal[
"gpt-4-1106-preview",
"gpt-4-32k",
"gpt-4",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo",
"text-embedding-ada-002",
]
]
class Config:
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("id") is None:
values.update({"id": str(uuid.uuid4())})
if values.get("mode") is None:
if values.get("mode") is None:
values.update({"mode": None})
if values.get("input_cost_per_token") is None:
if values.get("input_cost_per_token") is None:
values.update({"input_cost_per_token": None})
if values.get("output_cost_per_token") is None:
if values.get("output_cost_per_token") is None:
values.update({"output_cost_per_token": None})
if values.get("max_tokens") is None:
values.update({"max_tokens": None})
@ -97,21 +99,21 @@ class ModelInfo(LiteLLMBase):
return values
class ModelParams(LiteLLMBase):
model_name: str
litellm_params: dict
model_info: ModelInfo
class Config:
protected_namespaces = ()
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
return values
class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h"
models: Optional[list] = []
@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class UpdateKeyRequest(LiteLLMBase):
key: str
duration: Optional[str] = None
@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
"""
Return the row in the db
"""
api_key: Optional[str] = None
models: list = []
aliases: dict = {}
@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
duration: str = "1h"
metadata: dict = {}
class GenerateKeyResponse(LiteLLMBase):
key: str
expires: Optional[datetime]
user_id: str
class _DeleteKeyObject(LiteLLMBase):
key: str
class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject]
class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None
class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None
class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
"""
completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault")
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy")
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key")
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.")
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth")
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key")
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)")
background_health_checks: Optional[bool] = Field(None, description="run health checks in background")
health_check_interval: int = Field(300, description="background health check interval in seconds")
completion_model: Optional[str] = Field(
None, description="proxy level default model for all chat completion calls"
)
use_azure_key_vault: Optional[bool] = Field(
None, description="load keys from azure key vault"
)
master_key: Optional[str] = Field(
None, description="require a key for all calls to proxy"
)
database_url: Optional[str] = Field(
None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
otel: Optional[bool] = Field(
None,
description="[BETA] OpenTelemetry support - this might change, use with caution.",
)
custom_auth: Optional[str] = Field(
None,
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
)
max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key"
)
infer_model_from_keys: Optional[bool] = Field(
None,
description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
)
background_health_checks: Optional[bool] = Field(
None, description="run health checks in background"
)
health_check_interval: int = Field(
300, description="background health check interval in seconds"
)
class ConfigYAML(LiteLLMBase):
"""
Documents all the fields supported by the config.yaml
"""
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
model_list: Optional[List[ModelParams]] = Field(
None,
description="List of supported models on the server, with model-specific configs",
)
litellm_settings: Optional[dict] = Field(
None,
description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
)
general_settings: Optional[ConfigGeneralSettings] = None
class Config:
protected_namespaces = ()
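
Since these pydantic models mirror config.yaml, a short sketch of building and validating a proxy config directly in code; every concrete value below is a placeholder.

# Sketch: ConfigYAML / ConfigGeneralSettings / ModelParams are the models the
# proxy validates config.yaml against, so they can be constructed directly.
from litellm.proxy._types import ConfigGeneralSettings, ConfigYAML, ModelParams

config = ConfigYAML(
    model_list=[
        ModelParams(
            model_name="gpt-3.5-turbo",
            litellm_params={"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
            # model_info is filled in by the root_validator when omitted
        )
    ],
    litellm_settings={"drop_params": True},
    general_settings=ConfigGeneralSettings(
        master_key="sk-1234",
        max_parallel_requests=100,
        background_health_checks=True,
    ),
)

print(config.json())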


@ -1,14 +1,16 @@
from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
if api_key == modified_master_key:
return UserAPIKeyAuth(api_key=api_key)
raise Exception
except:
raise Exception
except:
raise Exception


@ -4,17 +4,19 @@ import sys, os, traceback
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
) # Adds the parent directory to the system path
from litellm.integrations.custom_logger import CustomLogger
import litellm
import inspect
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
if litellm.set_verbose:
print(print_statement) # noqa
class MyCustomHandler(CustomLogger):
def __init__(self):
@ -23,36 +25,38 @@ class MyCustomHandler(CustomLogger):
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
try:
print_verbose(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print_verbose the methods
for method in methods:
print_verbose(f" - {method}")
print_verbose(f"{reset_color_code}")
except:
pass
def log_pre_api_call(self, model, messages, kwargs):
def log_pre_api_call(self, model, messages, kwargs):
print_verbose(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj)
assert response_cost > 0.0
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print_verbose(f"On Async Failure !")
except Exception as e:
@ -64,4 +68,4 @@ proxy_handler_instance = MyCustomHandler()
# need to set litellm.callbacks = [customHandler] # on the proxy
# litellm.success_callback = [async_on_succes_logger]
# litellm.success_callback = [async_on_succes_logger]
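
A hedged sketch of the same callback pattern used directly against litellm rather than through proxy_config.yaml; the model choice and the presence of an OPENAI_API_KEY in the environment are assumptions.

# Sketch: subclass CustomLogger (as MyCustomHandler does above) and register it
# on litellm.callbacks so pre-call and success events are logged for normal calls.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class LoggingHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        print(f"calling {model} with {len(messages)} message(s)")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"success for {kwargs.get('model')} in {end_time - start_time}")


litellm.callbacks = [LoggingHandler()]

response = litellm.completion(
    model="gpt-3.5-turbo",  # assumes OPENAI_API_KEY is set
    messages=[{"role": "user", "content": "hi"}],
)
print(response)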


@ -12,25 +12,16 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__)
ILLEGAL_DISPLAY_PARAMS = [
"messages",
"api_key"
]
ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
def _get_random_llm_message():
"""
Get a random message from the LLM.
"""
messages = [
"Hey how's it going?",
"What's 1 + 1?"
]
messages = ["Hey how's it going?", "What's 1 + 1?"]
return [
{"role": "user", "content": random.choice(messages)}
]
return [{"role": "user", "content": random.choice(messages)}]
def _clean_litellm_params(litellm_params: dict):
@ -44,34 +35,40 @@ async def _perform_health_check(model_list: list):
"""
Perform a health check for each model in the list.
"""
async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None)
model_params["prompt"] = "test from litellm"
try:
await litellm.aimage_generation(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
async def _check_embedding_model(model_params: dict):
model_params.pop("messages", None)
model_params["input"] = ["test from litellm"]
try:
await litellm.aembedding(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
async def _check_model(model_params: dict):
try:
await litellm.acompletion(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
except Exception as e:
print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False
return True
tasks = []
@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list):
return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None):
async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
):
"""
Perform a health check on the system.
@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
"""
if not model_list:
if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]
model_list = [
{"model_name": cli_model, "litellm_params": {"model": cli_model}}
]
else:
return [], []
@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
return healthy_endpoints, unhealthy_endpoints
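
A hedged usage sketch for the health check helpers above; the litellm.proxy.health_check import path is assumed, and the probes need the relevant provider API keys in the environment.

# Sketch: perform_health_check fans out one probe per model (chat, embedding,
# or image generation depending on mode) and returns healthy/unhealthy lists.
import asyncio

from litellm.proxy.health_check import perform_health_check  # path assumed

model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    {"model_name": "claude", "litellm_params": {"model": "claude-2"}},
]

healthy, unhealthy = asyncio.run(perform_health_check(model_list=model_list))
print("healthy:", healthy)
print("unhealthy:", unhealthy)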


@ -6,35 +6,42 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
import json, traceback
class MaxBudgetLimiter(CustomLogger):
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
try:
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = cache.get_cache(cache_key)
if user_row is None: # value not yet cached
return
if user_row is None: # value not yet cached
return
max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]
if max_budget is None:
return
if curr_spend is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
except HTTPException as e:
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
traceback.print_exc()
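
A hedged, isolated sketch of the budget hook; the hook module path and the DualCache import are assumptions, and the cached user-row shape simply mirrors what async_pre_call_hook reads above.

# Sketch: seed the cache with a user row whose spend already exceeds
# max_budget, then confirm the pre-call hook raises the 429 HTTPException.
import asyncio

from fastapi import HTTPException
from litellm.caching import DualCache  # import path assumed
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter  # path assumed

cache = DualCache()
user = UserAPIKeyAuth(api_key="sk-1234", user_id="user-1")  # user_id field assumed
cache.set_cache(
    f"{user.user_id}_user_api_key_user_id", {"max_budget": 10.0, "spend": 12.5}
)


async def check():
    try:
        await MaxBudgetLimiter().async_pre_call_hook(
            user_api_key_dict=user,
            cache=cache,
            data={"model": "gpt-3.5-turbo", "messages": []},
            call_type="completion",
        )
    except HTTPException as e:
        print("blocked:", e.detail)


asyncio.run(check())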


@ -5,18 +5,25 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger):
class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
@ -26,8 +33,8 @@ class MaxParallelRequestsHandler(CustomLogger):
if max_parallel_requests is None:
return
self.user_api_key_cache = cache # save the api key cache for updating the value
self.user_api_key_cache = cache # save the api key cache for updating the value
# CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count"
@ -35,56 +42,67 @@ class MaxParallelRequestsHandler(CustomLogger):
self.print_verbose(f"current: {current}")
if current is None:
cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests:
elif int(current) < max_parallel_requests:
# Increase count for this token
cache.set_cache(request_count_api_key, int(current) + 1)
else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
else:
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
try:
self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
if user_api_key is None:
return
if self.user_api_key_cache is None:
if self.user_api_key_cache is None:
return
request_count_api_key = f"{user_api_key}_request_count"
# check if it has collected an entire stream response
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
self.print_verbose(
f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
)
if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
# Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1
self.print_verbose(f"updated_value in success call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e:
self.print_verbose(e) # noqa
except Exception as e:
self.print_verbose(e) # noqa
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
async def async_log_failure_call(
self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
):
try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
api_key = user_api_key_dict.api_key
if api_key is None:
return
if self.user_api_key_cache is None:
if self.user_api_key_cache is None:
return
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
pass # ignore failed calls due to max limit being reached
else:
if (
hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)
):
pass # ignore failed calls due to max limit being reached
else:
request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1
self.print_verbose(f"updated_value in failure call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e:
self.print_verbose(f"An exception occurred - {str(e)}") # noqa
self.print_verbose(f"An exception occurred - {str(e)}") # noqa


@ -6,85 +6,202 @@ from datetime import datetime
import importlib
from dotenv import load_dotenv
import operator
sys.path.append(os.getcwd())
config_filename = "litellm.secrets"
# Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
user_config_path = os.getenv(
"LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
)
load_dotenv()
from importlib import resources
import shutil
telemetry = None
def run_ollama_serve():
try:
command = ['ollama', 'serve']
with open(os.devnull, 'w') as devnull:
command = ["ollama", "serve"]
with open(os.devnull, "w") as devnull:
process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
print(f"""
print(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa
"""
) # noqa
def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo
repo_name = repo_url.split('/')[-1]
repo_master = os.path.join(destination, "repo_master")
subprocess.run(['git', 'clone', repo_url, repo_master])
# Clone the full repo
repo_name = repo_url.split("/")[-1]
repo_master = os.path.join(destination, "repo_master")
subprocess.run(["git", "clone", repo_url, repo_master])
# Move into the subfolder
subfolder_path = os.path.join(repo_master, subfolder)
# Move into the subfolder
subfolder_path = os.path.join(repo_master, subfolder)
# Copy subfolder to destination
for file_name in os.listdir(subfolder_path):
source = os.path.join(subfolder_path, file_name)
if os.path.isfile(source):
shutil.copy(source, destination)
else:
dest_path = os.path.join(destination, file_name)
shutil.copytree(source, dest_path)
# Copy subfolder to destination
for file_name in os.listdir(subfolder_path):
source = os.path.join(subfolder_path, file_name)
if os.path.isfile(source):
shutil.copy(source, destination)
else:
dest_path = os.path.join(destination, file_name)
shutil.copytree(source, dest_path)
# Remove cloned repo folder
subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy")
# Remove cloned repo folder
subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy")
def is_port_in_use(port):
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
return s.connect_ex(("localhost", port)) == 0
@click.command()
@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.')
@click.option('--port', default=8000, help='Port to bind the server to.')
@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up')
@click.option('--api_base', default=None, help='API base URL.')
@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.')
@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects')
@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")')
@click.option('--add_key', default=None, help='The model name to pass to litellm expects')
@click.option('--headers', default=None, help='headers for the API call')
@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config')
@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input')
@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints')
@click.option('--temperature', default=None, type=float, help='Set temperature for the model')
@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model')
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
@click.option('--local', is_flag=True, default=False, help='for local debugging')
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
@click.option("--host", default="0.0.0.0", help="Host for the server to listen on.")
@click.option("--port", default=8000, help="Port to bind the server to.")
@click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up")
@click.option("--api_base", default=None, help="API base URL.")
@click.option(
"--api_version",
default="2023-07-01-preview",
help="For azure - pass in the api version.",
)
@click.option(
"--model", "-m", default=None, help="The model name to pass to litellm expects"
)
@click.option(
"--alias",
default=None,
help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
)
@click.option(
"--add_key", default=None, help="The model name to pass to litellm expects"
)
@click.option("--headers", default=None, help="headers for the API call")
@click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
@click.option(
"--debug", default=False, is_flag=True, type=bool, help="To debug the input"
)
@click.option(
"--use_queue",
default=False,
is_flag=True,
type=bool,
help="To use celery workers for async endpoints",
)
@click.option(
"--temperature", default=None, type=float, help="Set temperature for the model"
)
@click.option(
"--max_tokens", default=None, type=int, help="Set max tokens for the model"
)
@click.option(
"--request_timeout",
default=600,
type=int,
help="Set timeout in seconds for completion calls",
)
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
@click.option(
"--add_function_to_prompt",
is_flag=True,
help="If function passed but unsupported, pass it as prompt",
)
@click.option(
"--config",
"-c",
default=None,
help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
)
@click.option(
"--max_budget",
default=None,
type=float,
help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
)
@click.option(
"--telemetry",
default=True,
type=bool,
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
)
@click.option(
"--version",
"-v",
default=False,
is_flag=True,
type=bool,
help="Print LiteLLM version",
)
@click.option(
"--logs",
flag_value=False,
type=int,
help='Gets the "n" most recent logs. By default gets most recent log.',
)
@click.option(
"--health",
flag_value=True,
help="Make a chat/completions request to all llms in config.yaml",
)
@click.option(
"--test",
flag_value=True,
help="proxy chat completions url to make a test request to",
)
@click.option(
"--test_async",
default=False,
is_flag=True,
help="Calls async endpoints /queue/requests and /queue/response",
)
@click.option(
"--num_requests",
default=10,
type=int,
help="Number of requests to hit async endpoint with",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(
host,
port,
api_base,
api_version,
model,
alias,
add_key,
headers,
save,
debug,
temperature,
max_tokens,
request_timeout,
drop_params,
add_function_to_prompt,
config,
max_budget,
telemetry,
logs,
test,
local,
num_workers,
test_async,
num_requests,
use_queue,
health,
version,
):
global feature_telemetry
args = locals()
if local:
@ -92,51 +209,60 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
else:
try:
from .proxy_server import app, save_worker_config, usage_telemetry
except ImportError as e:
except ImportError as e:
from proxy_server import app, save_worker_config, usage_telemetry
feature_telemetry = usage_telemetry
if logs is not None:
if logs == 0: # default to 1
if logs == 0: # default to 1
logs = 1
try:
with open('api_log.json') as f:
with open("api_log.json") as f:
data = json.load(f)
# convert keys to datetime objects
log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()}
# convert keys to datetime objects
log_times = {
datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()
}
# sort by timestamp
sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True)
# sort by timestamp
sorted_times = sorted(
log_times.items(), key=operator.itemgetter(0), reverse=True
)
# get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
recent_logs = {
k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]
}
print(json.dumps(recent_logs, indent=4)) # noqa
print(json.dumps(recent_logs, indent=4)) # noqa
except:
raise Exception("LiteLLM: No logs saved!")
return
if version == True:
pkg_version = importlib.metadata.version("litellm")
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
return
if model and "ollama" in model and api_base is None:
if model and "ollama" in model and api_base is None:
run_ollama_serve()
if test_async is True:
if test_async is True:
import requests, concurrent, time
api_base = f"http://{host}:{port}"
def _make_openai_completion():
def _make_openai_completion():
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Write a short poem about the moon"}]
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Write a short poem about the moon"}
],
}
response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
response = response.json()
while True:
try:
while True:
try:
url = response["url"]
polling_url = f"{api_base}{url}"
polling_response = requests.get(polling_url)
@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished":
llm_response = polling_response["result"]
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
) # noqa
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)
@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
futures = []
start_time = time.time()
# Make concurrent calls
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor:
with concurrent.futures.ThreadPoolExecutor(
max_workers=concurrent_calls
) as executor:
for _ in range(concurrent_calls):
futures.append(executor.submit(_make_openai_completion))
@ -171,7 +301,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
failed_calls = 0
for future in futures:
if future.done():
if future.done():
if future.result() is not None:
successful_calls += 1
else:
@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return
if health != False:
import requests
print("\nLiteLLM: Health Testing models in config")
response = requests.get(url=f"http://{host}:{port}/health")
print(json.dumps(response.json(), indent=4))
return
if test != False:
click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy')
click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy")
import openai
if test == True: # flag value set
api_base = f"http://{host}:{port}"
else:
api_base = test
client = openai.OpenAI(
api_key="My API Key",
base_url=api_base
)
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
], max_tokens=256)
click.echo(f'\nLiteLLM: response from proxy {response}')
if test == True: # flag value set
api_base = f"http://{host}:{port}"
else:
api_base = test
client = openai.OpenAI(api_key="My API Key", base_url=api_base)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem",
}
],
max_tokens=256,
)
click.echo(f"\nLiteLLM: response from proxy {response}")
print("\n Making streaming request to proxy")
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
stream=True,
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem",
}
],
stream=True,
)
for chunk in response:
click.echo(f'LiteLLM: streaming response from proxy {chunk}')
click.echo(f"LiteLLM: streaming response from proxy {chunk}")
print("\n making completion request to proxy")
response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem')
response = client.completions.create(
model="gpt-3.5-turbo", prompt="this is a test request, write a short poem"
)
print(response)
return
else:
if headers:
headers = json.loads(headers)
save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, use_queue=use_queue)
save_worker_config(
model=model,
alias=alias,
api_base=api_base,
api_version=api_version,
debug=debug,
temperature=temperature,
max_tokens=max_tokens,
request_timeout=request_timeout,
max_budget=max_budget,
telemetry=telemetry,
drop_params=drop_params,
add_function_to_prompt=add_function_to_prompt,
headers=headers,
save=save,
config=config,
use_queue=use_queue,
)
try:
import uvicorn
except:
raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
raise ImportError(
"Uvicorn needs to be imported. Run - `pip install uvicorn`"
)
if port == 8000 and is_port_in_use(port):
port = random.randint(1024, 49152)
uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers)
uvicorn.run(
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
)
if __name__ == "__main__":

File diff suppressed because it is too large


@ -1,71 +1,77 @@
from dotenv import load_dotenv
load_dotenv()
import json, subprocess
import psutil # Import the psutil library
import atexit
try:
### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
from celery import Celery
import redis
except:
import sys
# from dotenv import load_dotenv
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"redis",
"celery"
]
)
# load_dotenv()
# import json, subprocess
# import psutil # Import the psutil library
# import atexit
import time
import sys, os
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
import litellm
# try:
# ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
# from celery import Celery
# import redis
# except:
# import sys
# Redis connection setup
pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5)
redis_client = redis.Redis(connection_pool=pool)
# subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
# Celery setup
celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}")
celery_app.conf.update(
broker_pool_limit = None,
broker_transport_options = {'connection_pool': pool},
result_backend_transport_options = {'connection_pool': pool},
)
# import time
# import sys, os
# sys.path.insert(
# 0, os.path.abspath("../../..")
# ) # Adds the parent directory to the system path - for litellm local dev
# import litellm
# # Redis connection setup
# pool = redis.ConnectionPool(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# db=0,
# max_connections=5,
# )
# redis_client = redis.Redis(connection_pool=pool)
# # Celery setup
# celery_app = Celery(
# "tasks",
# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# )
# celery_app.conf.update(
# broker_pool_limit=None,
# broker_transport_options={"connection_pool": pool},
# result_backend_transport_options={"connection_pool": pool},
# )
# Celery task
@celery_app.task(name='process_job', max_retries=3)
def process_job(*args, **kwargs):
try:
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
response = llm_router.completion(*args, **kwargs) # type: ignore
if isinstance(response, litellm.ModelResponse):
response = response.model_dump_json()
return json.loads(response)
return str(response)
except Exception as e:
raise e
# # Celery task
# @celery_app.task(name="process_job", max_retries=3)
# def process_job(*args, **kwargs):
# try:
# llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
# response = llm_router.completion(*args, **kwargs) # type: ignore
# if isinstance(response, litellm.ModelResponse):
# response = response.model_dump_json()
# return json.loads(response)
# return str(response)
# except Exception as e:
# raise e
# Ensure Celery workers are terminated when the script exits
def cleanup():
try:
# Get a list of all running processes
for process in psutil.process_iter(attrs=['pid', 'name']):
# Check if the process is a Celery worker process
if process.info['name'] == 'celery':
print(f"Terminating Celery worker with PID {process.info['pid']}")
# Terminate the Celery worker process
psutil.Process(process.info['pid']).terminate()
except Exception as e:
print(f"Error during cleanup: {e}")
# Register the cleanup function to run when the script exits
atexit.register(cleanup)
# # Ensure Celery workers are terminated when the script exits
# def cleanup():
# try:
# # Get a list of all running processes
# for process in psutil.process_iter(attrs=["pid", "name"]):
# # Check if the process is a Celery worker process
# if process.info["name"] == "celery":
# print(f"Terminating Celery worker with PID {process.info['pid']}")
# # Terminate the Celery worker process
# psutil.Process(process.info["pid"]).terminate()
# except Exception as e:
# print(f"Error during cleanup: {e}")
# # Register the cleanup function to run when the script exits
# atexit.register(cleanup)


@ -1,12 +1,15 @@
import os
from multiprocessing import Process
def run_worker(cwd):
os.chdir(cwd)
os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info")
os.system(
"celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info"
)
def start_worker(cwd):
cwd += "/queue"
worker_process = Process(target=run_worker, args=(cwd,))
worker_process.start()


@ -1,26 +1,34 @@
import sys, os
from dotenv import load_dotenv
load_dotenv()
# Add the path to the local folder to sys.path
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
# import sys, os
# from dotenv import load_dotenv
def start_rq_worker():
from rq import Worker, Queue, Connection
from redis import Redis
# Set up RQ connection
redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
print(redis_conn.ping()) # Should print True if connected successfully
# Create a worker and add the queue
try:
queue = Queue(connection=redis_conn)
worker = Worker([queue], connection=redis_conn)
except Exception as e:
print(f"Error setting up worker: {e}")
exit()
with Connection(redis_conn):
worker.work()
# load_dotenv()
# # Add the path to the local folder to sys.path
# sys.path.insert(
# 0, os.path.abspath("../../..")
# ) # Adds the parent directory to the system path - for litellm local dev
start_rq_worker()
# def start_rq_worker():
# from rq import Worker, Queue, Connection
# from redis import Redis
# # Set up RQ connection
# redis_conn = Redis(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# )
# print(redis_conn.ping()) # Should print True if connected successfully
# # Create a worker and add the queue
# try:
# queue = Queue(connection=redis_conn)
# worker = Worker([queue], connection=redis_conn)
# except Exception as e:
# print(f"Error setting up worker: {e}")
# exit()
# with Connection(redis_conn):
# worker.work()
# start_rq_worker()


@ -4,20 +4,18 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="test",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
async def litellm_completion():
# Your existing code for litellm_completion goes here
try:
response = await litellm_client.chat.completions.create(
response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request
messages=[
{"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180}
], # this is about 4k tokens per request
)
print(response)
return response
except Exception as e:
@ -25,7 +23,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n")
pass
async def main():
@ -45,6 +42,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -4,16 +4,13 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="sk-1234",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
async def litellm_completion():
# Your existing code for litellm_completion goes here
try:
response = await litellm_client.chat.completions.create(
response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
)
@ -24,7 +21,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n")
pass
async def main():
@ -44,6 +40,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -1,4 +1,4 @@
# test time it takes to make 100 concurrent embedding requests to OpenAI
# test time it takes to make 100 concurrent embedding requests to OpenAI
import sys, os
import traceback
@ -14,16 +14,16 @@ import pytest
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
# Allow me to tune X concurrent calls. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
import concurrent.futures
import random
@ -35,7 +35,10 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"]
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create(
model="text-embedding-ada-002",
input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# )
return None
start_time = time.time()
# Number of concurrent calls (you can adjust this)
concurrent_calls = 500

View file

@ -4,24 +4,20 @@ import uuid
import traceback
litellm_client = AsyncOpenAI(
api_key="test",
base_url="http://0.0.0.0:8000"
)
litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
async def litellm_completion():
# Your existing code for litellm_completion goes here
try:
print("starting embedding calls")
response = await litellm_client.embeddings.create(
response = await litellm_client.embeddings.create(
model="text-embedding-ada-002",
input = [
"hello who are you" * 2000,
input=[
"hello who are you" * 2000,
"hello who are you tomorrow 1234" * 1000,
"hello who are you tomorrow 1234" * 1000
]
"hello who are you tomorrow 1234" * 1000,
],
)
print(response)
return response
@ -31,7 +27,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n")
pass
async def main():
@ -51,6 +46,7 @@ async def main():
print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__":
# Blank out contents of error_log.txt
open("error_log.txt", "w").close()

View file

@ -1,4 +1,4 @@
# test time it takes to make 100 concurrent embedding requests to OpenAI
# test time it takes to make 100 concurrent embedding requests to OpenAI
import sys, os
import traceback
@ -14,16 +14,16 @@ import pytest
import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
# Allow me to tune X concurrent calls. Log question, output/exception, response time somewhere
# show me a summary of requests made, successful calls, failed calls. For failed calls show me the exceptions
import concurrent.futures
import random
@ -35,7 +35,10 @@ def make_openai_completion(question):
try:
start_time = time.time()
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create(
model="text-embedding-ada-002",
input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# )
return None
start_time = time.time()
# Number of concurrent calls (you can adjust this)
concurrent_calls = 500

View file

@ -2,6 +2,7 @@ import requests
import time
import os
from dotenv import load_dotenv
load_dotenv()
@ -12,37 +13,35 @@ base_url = "https://api.litellm.ai"
# Step 1 Add a config to the proxy, generate a temp key
config = {
"model_list": [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'],
}
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.environ['AZURE_API_KEY'],
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview"
}
}
]
"model_list": [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ["OPENAI_API_KEY"],
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.environ["AZURE_API_KEY"],
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview",
},
},
]
}
print("STARTING LOAD TEST Q")
print(os.environ['AZURE_API_KEY'])
print(os.environ["AZURE_API_KEY"])
response = requests.post(
url=f"{base_url}/key/generate",
json={
"config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
},
headers={
"Authorization": "Bearer sk-hosted-litellm"
}
headers={"Authorization": "Bearer sk-hosted-litellm"},
)
print("\nresponse from generating key", response.text)
@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key)
import concurrent.futures
def create_job_and_poll(request_num):
print(f"Creating a job on the proxy for request {request_num}")
job_response = requests.post(
url=f"{base_url}/queue/request",
json={
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': 'write a short poem'},
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "write a short poem"},
],
},
headers={
"Authorization": f"Bearer {generated_key}"
}
headers={"Authorization": f"Bearer {generated_key}"},
)
print(job_response.status_code)
print(job_response.text)
@ -84,12 +82,12 @@ def create_job_and_poll(request_num):
try:
print(f"\nPolling URL for request {request_num}", polling_url)
polling_response = requests.get(
url=polling_url,
headers={
"Authorization": f"Bearer {generated_key}"
}
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
)
print(
f"\nResponse from polling url for request {request_num}",
polling_response.text,
)
print(f"\nResponse from polling url for request {request_num}", polling_response.text)
polling_response = polling_response.json()
status = polling_response.get("status", None)
if status == "finished":
@ -109,6 +107,7 @@ def create_job_and_poll(request_num):
except Exception as e:
print("got exception when polling", e)
# Number of requests
num_requests = 100
@ -118,4 +117,4 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor
futures = [executor.submit(create_job_and_poll, i) for i in range(num_requests)]
# Wait for all futures to complete
concurrent.futures.wait(futures)
concurrent.futures.wait(futures)
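The script above polls each queued request until its status is "finished". A small hedged sketch of the same pattern factored into a helper with an explicit timeout; poll_until_finished and its defaults are illustrative names, not part of the proxy API:

import time

import requests


def poll_until_finished(polling_url, api_key, timeout_s=120.0, interval_s=2.0):
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        polling_response = requests.get(
            url=polling_url, headers={"Authorization": f"Bearer {api_key}"}
        )
        body = polling_response.json()
        if body.get("status", None) == "finished":
            # Same shape the script above reads: the LLM output lives under "result".
            return body["result"]
        time.sleep(interval_s)  # still queued or in progress; wait before retrying
    raise TimeoutError(f"job at {polling_url} did not finish within {timeout_s}s")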

View file

@ -1,4 +1,4 @@
# # This tests the litellm proxy
# # This tests the litellm proxy
# # it makes async Completion requests with streaming
# import openai
@ -8,14 +8,14 @@
# async def test_async_completion():
# response = await (
# model="gpt-3.5-turbo",
# model="gpt-3.5-turbo",
# prompt='this is a test request, write a short poem',
# )
# print(response)
# print("test_streaming")
# response = await openai.chat.completions.create(
# model="gpt-3.5-turbo",
# model="gpt-3.5-turbo",
# prompt='this is a test request, write a short poem',
# stream=True
# )
@ -26,4 +26,3 @@
# import asyncio
# asyncio.run(test_async_completion())

View file

@ -34,7 +34,3 @@
# response = claude_chat(messages)
# print(response)

View file

@ -2,6 +2,7 @@ import requests
import time
import os
from dotenv import load_dotenv
load_dotenv()
@ -12,26 +13,24 @@ base_url = "https://api.litellm.ai"
# Step 1 Add a config to the proxy, generate a temp key
config = {
"model_list": [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'],
}
}
]
"model_list": [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.environ["OPENAI_API_KEY"],
},
}
]
}
response = requests.post(
url=f"{base_url}/key/generate",
json={
"config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key
"duration": "30d", # default to 30d, set it to 30m if you want a temp key
},
headers={
"Authorization": "Bearer sk-hosted-litellm"
}
headers={"Authorization": "Bearer sk-hosted-litellm"},
)
print("\nresponse from generating key", response.text)
@ -45,22 +44,23 @@ print("Creating a job on the proxy")
job_response = requests.post(
url=f"{base_url}/queue/request",
json={
'model': 'gpt-3.5-turbo',
'messages': [
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": f"You are a helpful assistant. What is your name",
},
],
},
headers={
"Authorization": f"Bearer {generated_key}"
}
headers={"Authorization": f"Bearer {generated_key}"},
)
print(job_response.status_code)
print(job_response.text)
print("\nResponse from creating job", job_response.text)
job_response = job_response.json()
job_id = job_response["id"] # type: ignore
polling_url = job_response["url"] # type: ignore
polling_url = f"{base_url}{polling_url}"
job_id = job_response["id"] # type: ignore
polling_url = job_response["url"] # type: ignore
polling_url = f"{base_url}{polling_url}"
print("\nCreated Job, Polling Url", polling_url)
# Step 3: Poll the request
@ -68,16 +68,13 @@ while True:
try:
print("\nPolling URL", polling_url)
polling_response = requests.get(
url=polling_url,
headers={
"Authorization": f"Bearer {generated_key}"
}
url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
)
print("\nResponse from polling url", polling_response.text)
polling_response = polling_response.json()
status = polling_response.get("status", None) # type: ignore
status = polling_response.get("status", None) # type: ignore
if status == "finished":
llm_response = polling_response["result"] # type: ignore
llm_response = polling_response["result"] # type: ignore
print("LLM Response")
print(llm_response)
break

View file

@ -8,16 +8,19 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status
def print_verbose(print_statement):
if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
"""
Logging/Custom Handlers for proxy.
Logging/Custom Handlers for proxy.
Implemented mainly to:
- log successful/failed db read/writes
- log successful/failed db read/writes
- support the max parallel request integration
"""
@ -25,15 +28,15 @@ class ProxyLogging:
## INITIALIZE LITELLM CALLBACKS ##
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter()
pass
def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter)
for callback in litellm.callbacks:
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:
@ -44,7 +47,7 @@ class ProxyLogging:
litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
@ -57,31 +60,41 @@ class ProxyLogging:
+ litellm.failure_callback
)
)
litellm.utils.set_callbacks(
callback_list=callback_list
)
litellm.utils.set_callbacks(callback_list=callback_list)
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal["completion", "embeddings"],
):
"""
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
Covers:
Covers:
1. /chat/completions
2. /embeddings
2. /embeddings
"""
try:
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
if response is not None:
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
callback.__class__
):
response = await callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None:
data = response
print_verbose(f'final data being sent to {call_type} call: {data}')
print_verbose(f"final data being sent to {call_type} call: {data}")
return data
except Exception as e:
raise e
async def success_handler(self, *args, **kwargs):
async def success_handler(self, *args, **kwargs):
"""
Log successful db read/writes
"""
@ -93,26 +106,31 @@ class ProxyLogging:
Currently only logs exceptions to sentry
"""
if litellm.utils.capture_exception:
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
async def post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
"""
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
Covers:
Covers:
1. /chat/completions
2. /embeddings
2. /embeddings
"""
for callback in litellm.callbacks:
try:
try:
if isinstance(callback, CustomLogger):
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
except Exception as e:
await callback.async_post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
)
except Exception as e:
raise e
return
### DB CONNECTOR ###
# Define the retry decorator with backoff strategy
@ -121,9 +139,12 @@ def on_backoff(details):
# The 'tries' key in the details dictionary contains the number of completed tries
print_verbose(f"Backing off... this was attempt #{details['tries']}")
class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object
self.proxy_logging_obj = proxy_logging_obj
@ -136,23 +157,24 @@ class PrismaClient:
os.chdir(dname)
try:
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just won't start on render. we need to have the --accept-data-loss
subprocess.run(["prisma", "generate"])
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just won't start on render. we need to have the --accept-data-loss
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
from prisma import Client # type: ignore
self.db = Client() # Client to connect to Prisma db
def hash_token(self, token: str):
# Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest()
return hashed_token
def jsonify_object(self, data: dict) -> dict:
def jsonify_object(self, data: dict) -> dict:
db_data = copy.deepcopy(data)
for k, v in db_data.items():
@ -162,233 +184,258 @@ class PrismaClient:
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None):
try:
async def get_data(
self,
token: Optional[str] = None,
expires: Optional[Any] = None,
user_id: Optional[str] = None,
):
try:
response = None
if token is not None:
if token is not None:
# check if plain text or hash
hashed_token = token
if token.startswith("sk-"):
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
)
where={"token": hashed_token}
)
if response:
# Token exists, now check expiration.
if response.expires is not None and expires is not None:
if response.expires is not None and expires is not None:
if response.expires >= expires:
# Token exists and is not expired.
return response
else:
# Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
return response
else:
# Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key")
elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
"user_id": user_id,
}
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid user key",
)
elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore
where={
"user_id": user_id,
}
)
return response
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def insert_data(self, data: dict):
"""
Add a key to the database. If it already exists, do nothing.
Add a key to the database. If it already exists, do nothing.
"""
try:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
"token": hashed_token,
},
data={
"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
"create": {**db_data}, # type: ignore
"update": {}, # don't do anything if it already exists
},
)
new_user_row = await self.db.litellm_usertable.upsert(
where={
'user_id': data['user_id']
},
where={"user_id": data["user_id"]},
data={
"create": {"user_id": data['user_id'], "max_budget": max_budget},
"update": {} # don't do anything if it already exists
}
"create": {"user_id": data["user_id"], "max_budget": max_budget},
"update": {}, # don't do anything if it already exists
},
)
return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None):
async def update_data(
self,
token: Optional[str] = None,
data: dict = {},
user_id: Optional[str] = None,
):
"""
Update existing data
"""
try:
try:
db_data = self.jsonify_object(data=data)
if token is not None:
if token is not None:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
if token.startswith("sk-"):
token = self.hash_token(token=token)
db_data["token"] = token
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token # type: ignore
},
data={**db_data} # type: ignore
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": db_data}
elif user_id is not None:
elif user_id is not None:
"""
If data['spend'] + data['user'], update the user table with spend info as well
"""
update_user_row = await self.db.litellm_usertable.update(
where={
'user_id': user_id # type: ignore
},
data={**db_data} # type: ignore
where={"user_id": user_id}, # type: ignore
data={**db_data}, # type: ignore
)
return {"user_id": user_id, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def delete_data(self, tokens: List):
"""
Allow user to delete a key(s)
"""
try:
try:
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
async def connect(self):
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
try:
await self.db.disconnect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
try:
print_verbose(f"value: {value}")
# Split the path by dots to separate module from instance
parts = value.split(".")
# The module path is all but the last part, and the instance_name is the last part
module_name = ".".join(parts[:-1])
instance_name = parts[-1]
# If config_file_path is provided, use it to determine the module spec and load the module
if config_file_path is not None:
directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, *module_name.split('.'))
module_file_path += '.py'
module_file_path = os.path.join(directory, *module_name.split("."))
module_file_path += ".py"
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
if spec is None:
raise ImportError(f"Could not find a module specification for {module_file_path}")
raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
spec.loader.exec_module(module) # type: ignore
else:
# Dynamically import the module
module = importlib.import_module(module_name)
# Get the instance from the module
instance = getattr(module, instance_name)
return instance
except ImportError as e:
# Re-raise the exception with a user-friendly message
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
except Exception as e:
except Exception as e:
raise e
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
"""
Check if a user_id exists in cache,
if not retrieve it.
Check if a user_id exists in cache,
if not retrieve it.
"""
cache_key = f"{user_id}_user_api_key_user_id"
response = cache.get_cache(key=cache_key)
if response is None: # Cache miss
if response is None: # Cache miss
user_row = await db.get_data(user_id=user_id)
cache_value = user_row.model_dump_json()
cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes
return
cache.set_cache(
key=cache_key, value=cache_value, ttl=600
) # store for 10 minutes
return
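All the PrismaClient methods above share the same retry policy: exponential backoff via the backoff library, capped at 3 tries or 10 seconds total, with an on_backoff hook logging the attempt count. A self-contained sketch of that decorator on a throwaway coroutine (flaky_db_read and log_backoff are illustrations, not proxy code):

import backoff


def log_backoff(details):
    # backoff passes a details dict; "tries" is the number of completed attempts.
    print(f"Backing off... this was attempt #{details['tries']}")


@backoff.on_exception(
    backoff.expo,  # exponential wait between retries
    Exception,  # base exception to catch for the backoff
    max_tries=3,  # give up after 3 total attempts
    max_time=10,  # or after 10 seconds of elapsed time
    on_backoff=log_backoff,
)
async def flaky_db_read():
    # Stand-in for an awaited Prisma query such as find_unique(...).
    raise ConnectionError("transient db error")


# asyncio.run(flaky_db_read())  # retries with exponential backoff, then re-raises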

File diff suppressed because it is too large

View file

@ -1,96 +1,121 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import dotenv, os, requests
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
class LeastBusyLoggingHandler(CustomLogger):
def __init__(self, router_cache: DualCache):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
self.mapping_deployment_to_id: dict = {}
def log_pre_api_call(self, model, messages, kwargs):
"""
Log when a model is being used.
Log when a model is being used.
Caching based on model group.
Caching based on model group.
"""
try:
if kwargs['litellm_params'].get('metadata') is None:
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
id = kwargs['litellm_params'].get('model_info', {}).get('id', None)
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if deployment is None or model_group is None or id is None:
return
# map deployment to id
self.mapping_deployment_to_id[deployment] = id
request_count_api_key = f"{model_group}_request_count"
# update cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = (
request_count_dict.get(deployment, 0) + 1
)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs['litellm_params'].get('metadata') is None:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None)
model_group = kwargs['litellm_params']['metadata'].get('model_group', None)
else:
deployment = kwargs["litellm_params"]["metadata"].get(
"deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None:
return
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e:
pass
def get_available_deployments(self, model_group: str):
request_count_api_key = f"{model_group}_request_count"
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
# map deployment to id
return_dict = {}
for key, value in request_count_dict.items():
return_dict[self.mapping_deployment_to_id[key]] = value
return return_dict
return return_dict
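get_available_deployments() above returns a traffic map of {model-id: requests_in_flight} for a model group. A tiny hedged sketch of the selection step the router is expected to perform on that map (pick_least_busy is an illustrative helper; in-flight counts can be None after a decrement, so they are treated as zero):

def pick_least_busy(deployment_traffic):
    # deployment_traffic: {"model-id": requests_in_flight}
    if not deployment_traffic:
        return None
    return min(
        deployment_traffic, key=lambda model_id: deployment_traffic[model_id] or 0
    )


# pick_least_busy({"id-64321": 3, "id-63960": 1}) -> "id-63960"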

View file

@ -2,6 +2,7 @@
import pytest, sys, os
import importlib
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -11,24 +12,30 @@ import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""
This fixture reloads litellm before every function, to speed up testing by removing callbacks being chained.
This fixture reloads litellm before every function, to speed up testing by removing callbacks being chained.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm
importlib.reload(litellm)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
custom_logger_tests = [
item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)
other_tests.sort(key=lambda x: x.name)
# Reorder the items list
items[:] = custom_logger_tests + other_tests
items[:] = custom_logger_tests + other_tests

View file

@ -4,6 +4,7 @@
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -11,53 +12,62 @@ import litellm
from litellm import Router
import concurrent
from dotenv import load_dotenv
load_dotenv()
model_list = [{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000
}
model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 1000000,
"rpm": 9000,
},
]
kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],}
kwargs = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
def test_multiple_deployments_sync():
import concurrent, time
litellm.set_verbose=False
results = []
router = Router(model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1) # type: ignore
try:
for _ in range(3):
response = router.completion(**kwargs)
results.append(response)
print(results)
router.reset()
except Exception as e:
print(f"FAILED TEST!")
pytest.fail(f"An error occurred - {traceback.format_exc()}")
def test_multiple_deployments_sync():
import concurrent, time
litellm.set_verbose = False
results = []
router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1,
) # type: ignore
try:
for _ in range(3):
response = router.completion(**kwargs)
results.append(response)
print(results)
router.reset()
except Exception as e:
print(f"FAILED TEST!")
pytest.fail(f"An error occurred - {traceback.format_exc()}")
# test_multiple_deployments_sync()
@ -67,13 +77,15 @@ def test_multiple_deployments_parallel():
results = []
futures = {}
start_time = time.time()
router = Router(model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1) # type: ignore
router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=1,
) # type: ignore
# Assuming you have an executor instance defined somewhere in your code
with concurrent.futures.ThreadPoolExecutor() as executor:
for _ in range(5):
@ -82,7 +94,11 @@ def test_multiple_deployments_parallel():
# Retrieve the results from the futures
while futures:
done, not_done = concurrent.futures.wait(futures.values(), timeout=10, return_when=concurrent.futures.FIRST_COMPLETED)
done, not_done = concurrent.futures.wait(
futures.values(),
timeout=10,
return_when=concurrent.futures.FIRST_COMPLETED,
)
for future in done:
try:
result = future.result()
@ -98,12 +114,14 @@ def test_multiple_deployments_parallel():
print(results)
print(f"ELAPSED TIME: {end_time - start_time}")
# Assuming litellm, router, and executor are defined somewhere in your code
# test_multiple_deployments_parallel()
def test_cooldown_same_model_name():
# users could have the same model with different api_base
# example
# example
# azure/chatgpt, api_base: 1234
# azure/chatgpt, api_base: 1235
# if 1234 fails, it should only cooldown 1234 and then try with 1235
@ -118,7 +136,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": "BAD_API_BASE",
"tpm": 90
"tpm": 90,
},
},
{
@ -128,7 +146,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"tpm": 0.000001
"tpm": 0.000001,
},
},
]
@ -140,17 +158,12 @@ def test_cooldown_same_model_name():
redis_port=int(os.getenv("REDIS_PORT")),
routing_strategy="simple-shuffle",
set_verbose=True,
num_retries=3
num_retries=3,
) # type: ignore
response = router.completion(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "hello this request will pass"
}
]
messages=[{"role": "user", "content": "hello this request will pass"}],
)
print(router.model_list)
model_ids = []
@ -159,10 +172,13 @@ def test_cooldown_same_model_name():
print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names
assert (
model_ids[0] != model_ids[1]
) # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response)
except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {e}")
test_cooldown_same_model_name()

View file

@ -9,68 +9,70 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
## case 1: add_function_to_prompt not set
## case 1: add_function_to_prompt not set
def test_function_call_non_openai_model():
try:
try:
model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}]
messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
]
response = litellm.completion(model=model, messages=messages, functions=functions)
pytest.fail(f'An error occurred')
except Exception as e:
response = litellm.completion(
model=model, messages=messages, functions=functions
)
pytest.fail(f"An error occurred")
except Exception as e:
print(e)
pass
test_function_call_non_openai_model()
## case 2: add_function_to_prompt set
## case 2: add_function_to_prompt set
def test_function_call_non_openai_model_litellm_mod_set():
litellm.add_function_to_prompt = True
try:
try:
model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}]
messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
]
response = litellm.completion(model=model, messages=messages, functions=functions)
print(f'response: {response}')
except Exception as e:
pytest.fail(f'An error occurred {e}')
response = litellm.completion(
model=model, messages=messages, functions=functions
)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"An error occurred {e}")
# test_function_call_non_openai_model_litellm_mod_set()
# test_function_call_non_openai_model_litellm_mod_set()
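The two cases above hinge on litellm.add_function_to_prompt: without it, passing functions to a non-OpenAI model raises; with it, the function schema is folded into the prompt text instead. A rough illustration of that idea only; this is not litellm's actual implementation, and the helper name is invented:

import json


def fold_functions_into_prompt(messages, functions):
    # Serialize the function schemas so a text-only model can still see them.
    schema_text = "\n".join(json.dumps(f, indent=2) for f in functions)
    system_hint = {
        "role": "system",
        "content": "You may respond with a JSON function call. Available functions:\n"
        + schema_text,
    }
    return [system_hint] + messages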

View file

@ -1,4 +1,3 @@
import sys, os
import traceback
from dotenv import load_dotenv
@ -8,7 +7,7 @@ import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
) # Adds the parent directory to the system path
import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout, acompletion
@ -20,18 +19,18 @@ import tempfile
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
messages = [{"content": user_message, "role": "user"}]
def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + '/vertex_key.json'
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, 'r') as file:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
@ -55,13 +54,13 @@ def load_vertex_ai_credentials():
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
@pytest.mark.asyncio
async def get_response():
@ -89,43 +88,80 @@ def test_vertex_ai():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
litellm.set_verbose=False
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
litellm.set_verbose = False
litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
print("making request", model)
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)
response = completion(
model=model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
)
print("\nModel Response", response)
print(response)
assert type(response.choices[0].message.content) == str
assert len(response.choices[0].message.content) > 1
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_vertex_ai()
def test_vertex_ai_stream():
load_vertex_ai_credentials()
litellm.set_verbose=False
litellm.set_verbose = False
litellm.vertex_project = "reliablekeys"
import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
print("making request", model)
response = completion(model=model, messages=[{"role": "user", "content": "write 10 lines of code for saying hi"}], stream=True)
response = completion(
model=model,
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
],
stream=True,
)
completed_str = ""
for chunk in response:
print(chunk)
@ -137,47 +173,86 @@ def test_vertex_ai_stream():
assert len(completed_str) > 4
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream()
# test_vertex_ai_stream()
@pytest.mark.asyncio
async def test_async_vertexai_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
print(f'model being tested in async call: {model}')
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model
continue
print(f"model being tested in async call: {model}")
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
response = await acompletion(
model=model, messages=messages, temperature=0.7, timeout=5
)
print(f"response: {response}")
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_response())
@pytest.mark.asyncio
async def test_async_vertexai_streaming_response():
import random
load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro
test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
# our account does not have access to this model
continue
if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model
continue
try:
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
response = await acompletion(
model="gemini-pro",
messages=messages,
temperature=0.7,
timeout=5,
stream=True,
)
print(f"response: {response}")
complete_response = ""
async for chunk in response:
@ -185,44 +260,46 @@ async def test_async_vertexai_streaming_response():
complete_response += chunk.choices[0].delta.content
print(f"complete_response: {complete_response}")
assert len(complete_response) > 0
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
print(e)
pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_streaming_response())
def test_gemini_pro_vision():
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
litellm.num_retries=0
litellm.num_retries = 0
resp = litellm.completion(
model = "vertex_ai/gemini-pro-vision",
model="vertex_ai/gemini-pro-vision",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Whats in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
}
}
]
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
},
},
],
}
],
)
print(resp)
except Exception as e:
import traceback
traceback.print_exc()
raise e
# test_gemini_pro_vision()
@ -333,4 +410,4 @@ def test_gemini_pro_vision():
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()
# test_gemini_pro_vision_async()

View file

@ -11,8 +11,10 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion, acreate
litellm.num_retries = 3
def test_sync_response():
litellm.set_verbose = False
user_message = "Hello, how are you?"
@ -20,35 +22,49 @@ def test_sync_response():
try:
response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
print(f"response: {response}")
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# test_sync_response()
def test_sync_response_anyscale():
litellm.set_verbose = False
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
except litellm.Timeout as e:
response = completion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# test_sync_response_anyscale()
def test_async_response_openai():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
response = await acompletion(
model="gpt-3.5-turbo", messages=messages, timeout=5
)
print(f"response: {response}")
print(f"response ms: {response._response_ms}")
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
@ -56,54 +72,75 @@ def test_async_response_openai():
asyncio.run(test_get_response())
# test_async_response_openai()
def test_async_response_azure():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "What do you know?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY"))
response = await acompletion(
model="azure/gpt-turbo",
messages=messages,
base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
)
print(f"response: {response}")
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_get_response())
# test_async_response_azure()
def test_async_anyscale_response():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
# response = await response
print(f"response: {response}")
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_get_response())
# test_async_anyscale_response()
def test_get_response_streaming():
import asyncio
async def test_async_call():
user_message = "write a short poem in one sentence"
messages = [{"content": user_message, "role": "user"}]
try:
litellm.set_verbose = True
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
response = await acompletion(
model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5
)
print(type(response))
import inspect
@ -116,29 +153,39 @@ def test_get_response_streaming():
async for chunk in response:
token = chunk["choices"][0]["delta"].get("content", "")
if token == None:
continue # openai v1.0.0 returns content=None
continue # openai v1.0.0 returns content=None
output += token
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
print(f'output: {output}')
except litellm.Timeout as e:
print(f"output: {output}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_async_call())
# test_get_response_streaming()
def test_get_response_non_openai_streaming():
import asyncio
litellm.set_verbose = True
litellm.num_retries = 0
async def test_async_call():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5)
response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
stream=True,
timeout=5,
)
print(type(response))
import inspect
@ -158,11 +205,13 @@ def test_get_response_non_openai_streaming():
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
except litellm.Timeout as e:
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
# test_get_response_non_openai_streaming()
# test_get_response_non_openai_streaming()


@ -3,79 +3,105 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
import openai, litellm, uuid
from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI(
api_key=os.getenv("AZURE_API_KEY"),
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
api_version=os.getenv("AZURE_API_VERSION")
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
api_version=os.getenv("AZURE_API_VERSION"),
)
model_list = [
{
"model_name": "azure-test",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION")
{
"model_name": "azure-test",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION"),
},
}
}
]
router = litellm.Router(model_list=model_list)
async def _openai_completion():
try:
start_time = time.time()
response = await client.chat.completions.create(
model="chatgpt-v-2",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print("OpenAI Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time)
return time_to_first_token
except Exception as e:
print(e)
return None
try:
start_time = time.time()
response = await client.chat.completions.create(
model="chatgpt-v-2",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True,
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print(
"OpenAI Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time,
)
return time_to_first_token
except Exception as e:
print(e)
return None
async def _router_completion():
try:
start_time = time.time()
response = await router.acompletion(
model="azure-test",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print("Router Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time - first_token_ts)
return time_to_first_token
except Exception as e:
print(e)
return None
try:
start_time = time.time()
response = await router.acompletion(
model="azure-test",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True,
)
time_to_first_token = None
first_token_ts = None
init_chunk = None
async for chunk in response:
if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time
init_chunk = chunk
end_time = time.time()
print(
"Router Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time - first_token_ts,
)
return time_to_first_token
except Exception as e:
print(e)
return None
async def test_azure_completion_streaming():
async def test_azure_completion_streaming():
"""
Test azure streaming call - measure on time to first (non-null) token.
Test azure streaming call - measure on time to first (non-null) token.
"""
n = 3 # Number of concurrent tasks
## OPENAI AVG. TIME
@ -83,19 +109,20 @@ async def test_azure_completion_streaming():
chat_completions = await asyncio.gather(*tasks)
successful_completions = [c for c in chat_completions if c is not None]
total_time = 0
for item in successful_completions:
total_time += item
avg_openai_time = total_time/3
for item in successful_completions:
total_time += item
avg_openai_time = total_time / 3
## ROUTER AVG. TIME
tasks = [_router_completion() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks)
successful_completions = [c for c in chat_completions if c is not None]
total_time = 0
for item in successful_completions:
total_time += item
avg_router_time = total_time/3
for item in successful_completions:
total_time += item
avg_router_time = total_time / 3
## COMPARE
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
assert avg_router_time < avg_openai_time + 0.5
# asyncio.run(test_azure_completion_streaming())
# asyncio.run(test_azure_completion_streaming())


@ -5,6 +5,7 @@
import sys, os
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@ -18,6 +19,7 @@ user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
model_val = None
def test_completion_with_no_model():
# test on empty
with pytest.raises(ValueError):
@ -32,9 +34,10 @@ def test_completion_with_empty_model():
print(f"error occurred: {e}")
pass
# def test_completion_catch_nlp_exception():
# TEMP commented out NLP cloud API is unstable
# try:
# try:
# response = completion(model="dolphin", messages=messages, functions=[
# {
# "name": "get_current_weather",
@ -56,65 +59,77 @@ def test_completion_with_empty_model():
# }
# ])
# except Exception as e:
# if "Function calling is not supported by nlp_cloud" in str(e):
# except Exception as e:
# if "Function calling is not supported by nlp_cloud" in str(e):
# pass
# else:
# pytest.fail(f'An error occurred {e}')
# test_completion_catch_nlp_exception()
# test_completion_catch_nlp_exception()
def test_completion_invalid_param_cohere():
try:
try:
response = completion(model="command-nightly", messages=messages, top_p=1)
print(f"response: {response}")
except Exception as e:
if "Unsupported parameters passed: top_p" in str(e):
except Exception as e:
if "Unsupported parameters passed: top_p" in str(e):
pass
else:
pytest.fail(f'An error occurred {e}')
else:
pytest.fail(f"An error occurred {e}")
# test_completion_invalid_param_cohere()
def test_completion_function_call_cohere():
try:
response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
pytest.fail(f'An error occurred {e}')
except Exception as e:
try:
response = completion(
model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]
)
pytest.fail(f"An error occurred {e}")
except Exception as e:
print(e)
pass
# test_completion_function_call_cohere()
def test_completion_function_call_openai():
try:
def test_completion_function_call_openai():
try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(model="gpt-3.5-turbo", messages=messages, functions=[
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
response = completion(
model="gpt-3.5-turbo",
messages=messages,
functions=[
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
}
},
"required": ["location"]
}
}
])
],
)
print(f"response: {response}")
except:
except:
pass
# test_completion_function_call_openai()
# test_completion_function_call_openai()
def test_completion_with_no_provider():
# test on empty
@ -125,6 +140,7 @@ def test_completion_with_no_provider():
print(f"error occurred: {e}")
pass
# test_completion_with_no_provider()
# # bad key
# temp_key = os.environ.get("OPENAI_API_KEY")
@ -136,4 +152,4 @@ def test_completion_with_no_provider():
# except:
# print(f"error occurred: {traceback.format_exc()}")
# pass
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5


@ -4,62 +4,78 @@
import sys, os
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from openai import APITimeoutError as Timeout
import litellm
litellm.num_retries = 0
from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses
from litellm import (
batch_completion,
batch_completion_models,
completion,
batch_completion_models_all_responses,
)
# litellm.set_verbose=True
def test_batch_completions():
messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)]
model = "j2-mid"
litellm.set_verbose = True
try:
result = batch_completion(
model=model,
model=model,
messages=messages,
max_tokens=10,
temperature=0.2,
request_timeout=1
request_timeout=1,
)
print(result)
print(len(result))
assert(len(result)==3)
assert len(result) == 3
except Timeout as e:
print(f"IN TIMEOUT")
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
test_batch_completions()
def test_batch_completions_models():
try:
result = batch_completion_models(
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
messages=[{"role": "user", "content": "Hey, how's it going"}]
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
messages=[{"role": "user", "content": "Hey, how's it going"}],
)
print(result)
except Timeout as e:
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
# test_batch_completions_models()
def test_batch_completion_models_all_responses():
try:
responses = batch_completion_models_all_responses(
models=["j2-light", "claude-instant-1.2"],
models=["j2-light", "claude-instant-1.2"],
messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10
max_tokens=10,
)
print(responses)
assert(len(responses) == 2)
assert len(responses) == 2
except Timeout as e:
pass
except Exception as e:
pytest.fail(f"An error occurred: {e}")
# test_batch_completion_models_all_responses()
# test_batch_completion_models_all_responses()


@ -3,12 +3,12 @@
# import sys, os, json
# import traceback
# import pytest
# import pytest
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import litellm
# import litellm
# litellm.set_verbose = True
# from litellm import completion, BudgetManager
@ -16,7 +16,7 @@
# ## Scenario 1: User budget enough to make call
# def test_user_budget_enough():
# try:
# try:
# user = "1234"
# # create a budget for a user
# budget_manager.create_budget(total_budget=10, user=user, duration="daily")
@ -38,7 +38,7 @@
# ## Scenario 2: User budget not enough to make call
# def test_user_budget_not_enough():
# try:
# try:
# user = "12345"
# # create a budget for a user
# budget_manager.create_budget(total_budget=0, user=user, duration="daily")
@ -60,7 +60,7 @@
# except:
# pytest.fail(f"An error occurred")
# ## Scenario 3: Saving budget to client
# ## Scenario 3: Saving budget to client
# def test_save_user_budget():
# try:
# response = budget_manager.save_data()
@ -70,17 +70,17 @@
# except Exception as e:
# pytest.fail(f"An error occurred: {str(e)}")
# test_save_user_budget()
# ## Scenario 4: Getting list of users
# test_save_user_budget()
# ## Scenario 4: Getting list of users
# def test_get_users():
# try:
# response = budget_manager.get_users()
# print(response)
# except:
# pytest.fail(f"An error occurred")
# pytest.fail(f"An error occurred")
# ## Scenario 5: Reset budget at the end of duration
# ## Scenario 5: Reset budget at the end of duration
# def test_reset_on_duration():
# try:
# # First, set a short duration budget for a user
@ -100,7 +100,7 @@
# # Now, we need to simulate the passing of time. Since we don't want our tests to actually take days, we're going
# # to cheat a little -- we'll manually adjust the "created_at" time so it seems like a day has passed.
# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time.
# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time.
# one_day_in_seconds = 24 * 60 * 60
# budget_manager.user_dict[user]["last_updated_at"] -= one_day_in_seconds
@ -108,11 +108,11 @@
# budget_manager.update_budget_all_users()
# # Make sure the budget was actually reset
# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired"
# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired"
# except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}")
# ## Scenario 6: passing in text:
# ## Scenario 6: passing in text:
# def test_input_text_on_completion():
# try:
# user = "12345"
@ -127,4 +127,4 @@
# except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}")
# test_input_text_on_completion()
# test_input_text_on_completion()
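
The commented-out budget test above mentions `freezegun` as the cleaner alternative to manually rewinding `last_updated_at`. A minimal sketch of that approach, assuming a hypothetical `budget_manager` fixture with the same `create_budget` / `update_budget_all_users` / `get_current_cost` methods used in the test (illustrative only, not part of this commit):

import datetime

from freezegun import freeze_time


def test_reset_on_duration_freezegun(budget_manager):
    # assumed fixture: a BudgetManager-like object with the methods used above
    user = "12345"
    budget_manager.create_budget(total_budget=10, user=user, duration="daily")

    with freeze_time(datetime.datetime.now()) as frozen_time:
        # ... make a completion call here and record its cost against the user ...
        frozen_time.tick(delta=datetime.timedelta(days=1))  # pretend a full day passed
        budget_manager.update_budget_all_users()
        assert budget_manager.get_current_cost(user) == 0  # budget reset after the duration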


@ -14,6 +14,7 @@ import litellm
from litellm import embedding, completion
from litellm.caching import Cache
import random
# litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}]
@ -22,23 +23,30 @@ messages = [{"role": "user", "content": "who is ishaan Github? "}]
import random
import string
def generate_random_word(length=4):
letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length))
return "".join(random.choice(letters) for _ in range(length))
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
def test_caching_v2(): # test in memory cache
try:
litellm.set_verbose=True
litellm.set_verbose = True
litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
litellm.cache = None # disable cache
litellm.success_callback = []
litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_v2()
def test_caching_with_models_v2():
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}]
messages = [
{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
]
litellm.cache = Cache()
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@ -63,34 +73,51 @@ def test_caching_with_models_v2():
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']:
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
# test_caching_with_models_v2()
embedding_large_text = """
embedding_large_text = (
"""
small text
""" * 5
"""
* 5
)
# # test_caching_with_models()
def test_embedding_caching():
import time
litellm.cache = Cache()
text_to_embed = [embedding_large_text]
start_time = time.time()
embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
embedding1 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds")
time.sleep(1)
start_time = time.time()
embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True)
embedding2 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time()
print(f"embedding2: {embedding2}")
print(f"Embedding 2 response time: {end_time - start_time} seconds")
@ -98,29 +125,30 @@ def test_embedding_caching():
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
# test_embedding_caching()
def test_embedding_caching_azure():
print("Testing azure embedding caching")
import time
litellm.cache = Cache()
text_to_embed = [embedding_large_text]
api_key = os.environ['AZURE_API_KEY']
api_base = os.environ['AZURE_API_BASE']
api_version = os.environ['AZURE_API_VERSION']
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ["AZURE_API_VERSION"]
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
start_time = time.time()
print("AZURE CONFIGS")
@ -133,7 +161,7 @@ def test_embedding_caching_azure():
api_key=api_key,
api_base=api_base,
api_version=api_version,
caching=True
caching=True,
)
end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds")
@ -146,7 +174,7 @@ def test_embedding_caching_azure():
api_key=api_key,
api_base=api_base,
api_version=api_version,
caching=True
caching=True,
)
end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds")
@ -154,15 +182,16 @@ def test_embedding_caching_azure():
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")
os.environ['AZURE_API_VERSION'] = api_version
os.environ['AZURE_API_BASE'] = api_base
os.environ['AZURE_API_KEY'] = api_key
os.environ["AZURE_API_VERSION"] = api_version
os.environ["AZURE_API_BASE"] = api_base
os.environ["AZURE_API_KEY"] = api_key
# test_embedding_caching_azure()
@ -170,13 +199,28 @@ def test_embedding_caching_azure():
def test_redis_cache_completion():
litellm.set_verbose = False
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
response1 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response2 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response3 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
)
response4 = completion(model="command-nightly", messages=messages, caching=True)
print("\nresponse 1", response1)
@ -192,49 +236,88 @@ def test_redis_cache_completion():
1 & 3 should be different, since input params are diff
1 & 4 should be diff, since models are diff
"""
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']:
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like seed, max_tokens are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:")
if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']:
pytest.fail(
f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:"
)
if (
response1["choices"][0]["message"]["content"]
== response4["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response
print(f"response1: {response1}")
print(f"response4: {response4}")
pytest.fail(f"Error occurred:")
# test_redis_cache_completion()
def test_redis_cache_completion_stream():
try:
litellm.success_callback = []
litellm._async_success_callback = []
litellm.callbacks = []
litellm.set_verbose = True
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_1_content = ""
for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
time.sleep(0.5)
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True)
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_2_content = ""
for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = []
litellm.cache = None
litellm.success_callback = []
@ -247,99 +330,171 @@ def test_redis_cache_completion_stream():
1 & 2 should be exactly the same
"""
# test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream():
import asyncio
try:
litellm.set_verbose = True
random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
nonlocal response_1_content
response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True)
response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock():
import asyncio
try:
litellm.set_verbose = True
random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion")
response_1_content = ""
response_2_content = ""
async def call1():
nonlocal response_1_content
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
nonlocal response_1_content
response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1:
print(chunk)
response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content)
asyncio.run(call1())
time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2():
nonlocal response_2_content
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2:
print(chunk)
response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content)
asyncio.run(call2())
print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
except Exception as e:
print(e)
raise e
# test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys
def custom_get_cache_key(*args, **kwargs):
# return key to use for your cache:
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", ""))
# return key to use for your cache:
key = (
kwargs.get("model", "")
+ str(kwargs.get("messages", ""))
+ str(kwargs.get("temperature", ""))
+ str(kwargs.get("logit_bias", ""))
)
return key
def test_custom_redis_cache_with_key():
messages = [{"role": "user", "content": "write a one line story"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.cache.get_cache_key = custom_get_cache_key
local_cache = {}
@ -356,54 +511,72 @@ def test_custom_redis_cache_with_key():
# patch this redis cache get and set call
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response3 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=False,
num_retries=3,
)
print(f"response1: {response1}")
print(f"response2: {response2}")
print(f"response3: {response3}")
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
pytest.fail(f"Error occurred:")
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# test_custom_redis_cache_with_key()
def test_cache_override():
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
# in this case it should not return cached responses
# in this case it should not return cached responses
litellm.cache = Cache()
print("Testing cache override")
litellm.set_verbose=True
litellm.set_verbose = True
# test embedding
response1 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
model="text-embedding-ada-002", input=["hello who are you"], caching=False
)
start_time = time.time()
response2 = embedding(
model = "text-embedding-ada-002",
input=[
"hello who are you"
],
caching = False
model="text-embedding-ada-002", input=["hello who are you"], caching=False
)
end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override()
assert (
end_time - start_time > 0.1
) # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override()
def test_custom_redis_cache_params():
@ -411,17 +584,17 @@ def test_custom_redis_cache_params():
try:
litellm.cache = Cache(
type="redis",
host=os.environ['REDIS_HOST'],
port=os.environ['REDIS_PORT'],
password=os.environ['REDIS_PASSWORD'],
db = 0,
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
db=0,
ssl=True,
ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key",
ssl_ca_certs="./redis_ca.pem",
)
print(litellm.cache.cache.redis_client)
print(litellm.cache.cache.redis_client)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
@ -431,58 +604,126 @@ def test_custom_redis_cache_params():
def test_get_cache_key():
from litellm.caching import Cache
try:
print("Testing get_cache_key")
cache_instance = Cache()
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
cache_key = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
)
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
cache_key_2 = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
)
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
assert (
cache_key
== "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
)
assert (
cache_key == cache_key_2
), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
embedding_cache_key = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
**{
"model": "azure/azure-embedding-model",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_key": "",
"api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
}
)
print(embedding_cache_key)
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
assert (
embedding_cache_key
== "model: azure/azure-embedding-modelinput: ['hi who is ishaan']"
), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
# Proxy - embedding cache, test if embedding key, gets model_group and not model
embedding_cache_key_2 = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
'api_key': '', 'api_version': '2023-07-01-preview',
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
'caching': True,
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
'method': 'POST',
'headers':
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
'content-length': '80'},
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
'user': None,
'metadata': {'user_api_key': None,
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
'model_group': 'EMBEDDING_MODEL_GROUP',
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
**{
"model": "azure/azure-embedding-model",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_key": "",
"api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
"proxy_server_request": {
"url": "http://0.0.0.0:8000/embeddings",
"method": "POST",
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"body": {
"model": "azure-embedding-model",
"input": ["hi who is ishaan"],
},
},
"user": None,
"metadata": {
"user_api_key": None,
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"model_group": "EMBEDDING_MODEL_GROUP",
"deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview",
},
"model_info": {
"mode": "embedding",
"base_model": "text-embedding-ada-002",
"id": "20b2b515-f151-4dd5-a74f-2231e2f54e29",
},
"litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e",
"litellm_logging_obj": "<litellm.utils.Logging object at 0x12f1bddb0>",
}
)
print(embedding_cache_key_2)
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
assert (
embedding_cache_key_2
== "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
)
print("passed!")
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred:", e)
test_get_cache_key()
# test_custom_redis_cache_params()
@ -581,4 +822,4 @@ test_get_cache_key()
# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
# time.sleep(2)
# assert cache.get_cache(cache_key="test_key") is None
# # test_in_memory_cache_with_ttl()
# # test_in_memory_cache_with_ttl()


@ -1,5 +1,5 @@
#### What this tests ####
# This tests using caching w/ litellm which requires SSL=True
# This tests using caching w/ litellm which requires SSL=True
import sys, os
import time
@ -18,15 +18,26 @@ from litellm import embedding, completion, Router
from litellm.caching import Cache
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
def test_caching_v2(): # test in memory cache
def test_caching_v2(): # test in memory cache
try:
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
litellm.cache = None # disable cache
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
raise Exception()
@ -34,41 +45,57 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_v2()
def test_caching_router():
"""
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
"""
try:
try:
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
}
]
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2")
router = Router(model_list=model_list,
routing_strategy="simple-shuffle",
set_verbose=False,
num_retries=1) # type: ignore
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
}
]
litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
router = Router(
model_list=model_list,
routing_strategy="simple-shuffle",
set_verbose=False,
num_retries=1,
) # type: ignore
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}")
print(f"response2: {response2}")
litellm.cache = None # disable cache
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content']
litellm.cache = None # disable cache
assert (
response2["choices"][0]["message"]["content"]
== response1["choices"][0]["message"]["content"]
)
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
# test_caching_router()
# test_caching_router()


@ -8,7 +8,7 @@
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import litellm
# import asyncio
# import asyncio
# litellm.set_verbose = True
# from litellm import Router
@ -18,9 +18,9 @@
# # This enables response_model keyword
# # # from client.chat.completions.create
# # client = instructor.patch(Router(model_list=[{
# # "model_name": "gpt-3.5-turbo", # openai model name
# # "litellm_params": { # params for litellm completion/embedding call
# # "model": "azure/chatgpt-v-2",
# # "model_name": "gpt-3.5-turbo", # openai model name
# # "litellm_params": { # params for litellm completion/embedding call
# # "model": "azure/chatgpt-v-2",
# # "api_key": os.getenv("AZURE_API_KEY"),
# # "api_version": os.getenv("AZURE_API_VERSION"),
# # "api_base": os.getenv("AZURE_API_BASE")
@ -49,9 +49,9 @@
# from openai import AsyncOpenAI
# aclient = instructor.apatch(Router(model_list=[{
# "model_name": "gpt-3.5-turbo", # openai model name
# "litellm_params": { # params for litellm completion/embedding call
# "model": "azure/chatgpt-v-2",
# "model_name": "gpt-3.5-turbo", # openai model name
# "litellm_params": { # params for litellm completion/embedding call
# "model": "azure/chatgpt-v-2",
# "api_key": os.getenv("AZURE_API_KEY"),
# "api_version": os.getenv("AZURE_API_VERSION"),
# "api_base": os.getenv("AZURE_API_BASE")
@ -71,4 +71,4 @@
# )
# print(f"model: {model}")
# asyncio.run(main())
# asyncio.run(main())

File diff suppressed because it is too large


@ -28,6 +28,7 @@ def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}")
pass
# normal call
def test_completion_custom_provider_model_name():
try:
@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# completion with num retries + impact on exception mapping
def test_completion_with_num_retries():
try:
response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2)
def test_completion_with_num_retries():
try:
response = completion(
model="j2-ultra",
messages=[{"messages": "vibe", "bad": "message"}],
num_retries=2,
)
pytest.fail(f"Unmapped exception occurred")
except Exception as e:
except Exception as e:
pass
# test_completion_with_num_retries()
def test_completion_with_0_num_retries():
try:
litellm.set_verbose=False
litellm.set_verbose = False
print("making request")
# Use the completion function
response = completion(
model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}],
max_retries=4
max_retries=4,
)
print(response)
@ -69,5 +76,6 @@ def test_completion_with_0_num_retries():
print("exception", e)
pass
# Call the test function
test_completion_with_0_num_retries()
test_completion_with_0_num_retries()


@ -15,77 +15,104 @@ from litellm import completion_with_config
config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"model": {
"claude-instant-1": {
"needs_moderation": True
},
"claude-instant-1": {"needs_moderation": True},
"gpt-3.5-turbo": {
"error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
}
}
}
},
},
}
def test_config_context_window_exceeded():
try:
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_window_exceeded()
# test_config_context_window_exceeded()
def test_config_context_moderation():
try:
messages=[{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
messages = [{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(
model="claude-instant-1", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_moderation()
# test_config_context_moderation()
def test_config_context_default_fallback():
try:
messages=[{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key")
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(
model="claude-instant-1",
messages=messages,
config=config,
api_key="bad-key",
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
# test_config_context_default_fallback()
# test_config_context_default_fallback()
config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"],
"adapt_to_prompt_size": True, # type: ignore
"available_models": [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"j2-ultra",
"command-nightly",
"togethercomputer/llama-2-70b-chat",
"chat-bison",
"chat-bison@001",
"claude-2",
],
"adapt_to_prompt_size": True, # type: ignore
"model": {
"claude-instant-1": {
"needs_moderation": True
},
"claude-instant-1": {"needs_moderation": True},
"gpt-3.5-turbo": {
"error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
}
}
}
},
},
}
def test_config_context_adapt_to_prompt():
try:
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response)
except Exception as e:
print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}")
test_config_context_adapt_to_prompt()
test_config_context_adapt_to_prompt()


@ -1,14 +1,16 @@
from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
print(f"api_key: {api_key}")
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
return UserAPIKeyAuth(api_key=api_key)
raise Exception
except:
raise Exception
except:
raise Exception


@ -2,30 +2,35 @@ from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm
class testCustomCallbackProxy(CustomLogger):
def __init__(self):
self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore
self.async_success: bool = False # type: ignore
self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore
self.async_success: bool = False # type: ignore
self.async_success_embedding: bool = False # type: ignore
self.async_failure: bool = False # type: ignore
self.async_failure: bool = False # type: ignore
self.async_failure_embedding: bool = False # type: ignore
self.async_completion_kwargs = None # type: ignore
self.async_embedding_kwargs = None # type: ignore
self.async_embedding_response = None # type: ignore
self.async_completion_kwargs = None # type: ignore
self.async_embedding_kwargs = None # type: ignore
self.async_embedding_response = None # type: ignore
self.async_completion_kwargs_fail = None # type: ignore
self.async_embedding_kwargs_fail = None # type: ignore
self.async_completion_kwargs_fail = None # type: ignore
self.async_embedding_kwargs_fail = None # type: ignore
self.streaming_response_obj = None # type: ignore
self.streaming_response_obj = None # type: ignore
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
print(f"{blue_color_code}Initialized LiteLLM custom logger")
try:
print(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print the methods
for method in methods:
print(f" - {method}")
@ -33,29 +38,32 @@ class testCustomCallbackProxy(CustomLogger):
except:
pass
def log_pre_api_call(self, model, messages, kwargs):
def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Success")
self.success = True
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Failure")
self.failure = True
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async success")
self.async_success = True
print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs)
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
if (
kwargs.get("model") == "azure-embedding-model"
or kwargs.get("model") == "ada"
):
print("Got an embedding model", kwargs.get("model"))
print("Setting embedding success to True")
self.async_success_embedding = True
@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("stream") == True:
self.streaming_response_obj = response_obj
self.async_completion_kwargs = kwargs
model = kwargs.get("model", None)
@ -74,17 +81,18 @@ class testCustomCallbackProxy(CustomLogger):
# Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here
metadata = litellm_params.get(
"metadata", {}
) # headers passed to LiteLLM proxy, can be found here
# Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj)
response = response_obj
# tokens used in response
# tokens used in response
usage = response_obj["usage"]
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print(
f"""
Model: {model},
@ -98,8 +106,7 @@ class testCustomCallbackProxy(CustomLogger):
)
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure")
self.async_failure = True
print("Value of async failure: ", self.async_failure)
@ -107,7 +114,8 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("model") == "text-embedding-ada-002":
self.async_failure_embedding = True
self.async_embedding_kwargs_fail = kwargs
self.async_completion_kwargs_fail = kwargs
my_custom_logger = testCustomCallbackProxy()
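A minimal sketch of how a handler instance like the one above can be wired in, assuming litellm.callbacks accepts CustomLogger instances the way the router tests later in this diff do; the model and message values are placeholders.

import litellm
from litellm import completion

# register the handler so every subsequent call fires its hooks
litellm.callbacks = [my_custom_logger]

# a normal completion call now triggers log_pre_api_call, log_post_api_call,
# and log_success_event (or log_failure_event) on my_custom_logger
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)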

File diff suppressed because it is too large


@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List
from litellm import Router, Cache
import litellm
@ -14,206 +15,274 @@ from litellm.integrations.custom_logger import CustomLogger
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## fallbacks
## retries
# Test cases
## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks
# Test interfaces
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
print(f'received kwargs in pre-input: {kwargs}')
def log_pre_api_call(self, model, messages, kwargs):
try:
print(f"received kwargs in pre-input: {kwargs}")
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert end_time == None
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
"""
No-op.
Not implemented yet.
"""
pass
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## RESPONSE OBJECT
assert isinstance(
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
@ -221,10 +290,14 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
assert isinstance(
kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -232,257 +305,281 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
try:
print(f"received original response: {kwargs['original_response']}")
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], (dict, type(None)))
assert isinstance(kwargs['log_event_type'], str)
except:
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure():
try:
customHandler_completion_azure_router = CompletionCustomHandler()
customHandler_streaming_azure_router = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_completion_azure_router]
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
await asyncio.sleep(2)
assert len(customHandler_completion_azure_router.errors) == 0
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
# streaming
assert (
len(customHandler_completion_azure_router.states) == 3
) # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
print(f"async azure router chunk: {chunk}")
continue
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
assert len(customHandler_streaming_azure_router.errors) == 0
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
# failure
assert (
len(customHandler_streaming_azure_router.states) >= 4
) # pre, post, stream (multiple times), success
# failure
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
print(f"response in router3 acompletion: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure())
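A compact sketch of the verification pattern the router tests above repeat, assuming the CompletionCustomHandler class and a configured Router plus the litellm/asyncio imports from the surrounding code: register a fresh handler, make one call, wait for the async callbacks, then check the recorded states and errors.

async def check_handler_states(router):
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
    )
    await asyncio.sleep(2)  # success callbacks fire asynchronously
    assert len(customHandler.errors) == 0  # no assertion failures inside the hooks
    assert len(customHandler.states) == 3  # pre, post, success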
## EMBEDDING
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler]
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
)
await asyncio.sleep(2)
assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success
# failure
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.aembedding(
model="azure-embedding-model", input=["hello from litellm!"]
)
print(f"response in router3 aembedding: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure_with_fallbacks():
try:
customHandler_fallbacks = CompletionCustomHandler()
litellm.callbacks = [customHandler_fallbacks]
# with fallbacks
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
]
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
await asyncio.sleep(2)
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
assert len(customHandler_fallbacks.errors) == 0
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
assert (
len(customHandler_fallbacks.states) == 6
) # pre, post, failure, pre, post, success
litellm.callbacks = []
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure_with_fallbacks())
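A condensed sketch of the fallback wiring exercised above: the fallbacks argument maps a model group to an ordered list of backup groups, so when the primary deployment fails (here, a deliberately bad key) the router retries on the next group and the handler records pre, post, failure, then pre, post, success.

import asyncio, os
import litellm
from litellm import Router

async def demo_fallbacks():
    model_list = [
        {  # primary group with a bad key, so the first attempt fails
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "my-bad-key",
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
        {  # backup group the router falls back to
            "model_name": "gpt-3.5-turbo-16k",
            "litellm_params": {"model": "gpt-3.5-turbo-16k"},
        },
    ]
    router = Router(
        model_list=model_list,
        fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
    )
    return await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
    )

# asyncio.run(demo_fallbacks())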
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800
}
]
router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
"tpm": 240000,
"rpm": 1800,
},
{
"model_name": "gpt-3.5-turbo-16k",
"litellm_params": {
"model": "gpt-3.5-turbo-16k",
},
"tpm": 240000,
"rpm": 1800,
},
]
router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": f"Hi 👋 - i'm async azure {unique_time}"
}],
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
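A minimal sketch, reusing the Redis cache configuration above, of how a success hook can tell a cache hit apart from a fresh call: the cache_hit flag asserted in the handler earlier arrives in the callback kwargs. The handler name here is illustrative.

import os
import litellm
from litellm import Cache
from litellm.integrations.custom_logger import CustomLogger

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)

class CacheAwareHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # per the assertions above, kwargs["cache_hit"] is None or a bool
        if kwargs.get("cache_hit"):
            print("served from cache")
        else:
            print("fresh completion")

litellm.callbacks = [CacheAwareHandler()]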


@ -1,12 +1,14 @@
import sys
import os
import io, asyncio
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion
import litellm
litellm.num_retries = 3
import time, random
@ -29,11 +31,14 @@ def pre_request():
import re
def verify_log_file(log_file_path):
with open(log_file_path, 'r') as log_file:
def verify_log_file(log_file_path):
with open(log_file_path, "r") as log_file:
log_content = log_file.read()
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)
print(
f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content
)
# Define the pattern to search for in the log file
pattern = r"Response from DynamoDB:{.*?}"
@ -50,17 +55,21 @@ def verify_log_file(log_file_path):
print(f"Total occurrences of specified response: {len(matches)}")
# Count the occurrences of successful responses (status code 200 or 201)
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)
success_count = sum(
1
for match in matches
if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
)
# Print the count of successful responses
print(f"Count of successful responses from DynamoDB: {success_count}")
assert success_count == 3 # Expect 3 success logs from dynamoDB
def test_dynamo_logging():
# all dynamodb requests need to be in one test function
# since we are modifying stdout, and pytests runs tests in parallel
try:
# pre
# redirect stdout to log_file
@ -69,44 +78,44 @@ def test_dynamo_logging():
litellm.set_verbose = True
original_stdout, log_file, file_name = pre_request()
print("Testing async dynamoDB logging")
async def _test():
return await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=100,
temperature=0.7,
user = "ishaan-2"
user="ishaan-2",
)
response = asyncio.run(_test())
print(f"response: {response}")
# streaming + async
async def _test2():
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}],
messages=[{"role": "user", "content": "This is a test"}],
max_tokens=10,
temperature=0.7,
user = "ishaan-2",
stream=True
user="ishaan-2",
stream=True,
)
async for chunk in response:
pass
asyncio.run(_test2())
# aembedding()
async def _test3():
return await litellm.aembedding(
model="text-embedding-ada-002",
input = ["hi"],
user = "ishaan-2"
model="text-embedding-ada-002", input=["hi"], user="ishaan-2"
)
response = asyncio.run(_test3())
time.sleep(1)
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
finally:
# post, close log file and verify
@ -117,4 +126,5 @@ def test_dynamo_logging():
verify_log_file(file_name)
print("Passed! Testing async dynamoDB logging")
# test_dynamo_logging_async()
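A rough sketch of the stdout-capture helpers this DynamoDB test leans on; pre_request's body is elided above, so the shape below and the file-name scheme are assumed rather than taken from the source, and post_request is a hypothetical counterpart mirroring the inline cleanup in the finally block.

import sys, time

def pre_request():
    # assumed: divert stdout into a log file so DynamoDB responses can be inspected later
    file_name = f"dynamo_{time.time()}.log"  # hypothetical naming scheme
    log_file = open(file_name, "a+")
    original_stdout = sys.stdout
    sys.stdout = log_file
    return original_stdout, log_file, file_name

def post_request(original_stdout, log_file):
    # assumed counterpart: restore stdout and close the capture before verify_log_file reads it
    sys.stdout = original_stdout
    log_file.close()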


@ -14,39 +14,49 @@ from litellm import embedding, completion
litellm.set_verbose = False
def test_openai_embedding():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
metadata = {"anything": "good day"}
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
metadata={"anything": "good day"},
)
litellm_response = dict(response)
litellm_response_keys = set(litellm_response.keys())
litellm_response_keys.discard('_response_ms')
litellm_response_keys.discard("_response_ms")
print(litellm_response_keys)
print("LiteLLM Response\n")
# print(litellm_response)
# same request with OpenAI 1.0+
import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'])
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
response = client.embeddings.create(
model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"]
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
)
response = dict(response)
openai_response_keys = set(response.keys())
print(openai_response_keys)
assert litellm_response_keys == openai_response_keys # ENSURE the Keys in litellm response is exactly what the openai package returns
assert len(litellm_response["data"]) == 2 # expect two embedding responses from litellm_response since input had two
assert (
litellm_response_keys == openai_response_keys
) # ENSURE the Keys in litellm response is exactly what the openai package returns
assert (
len(litellm_response["data"]) == 2
) # expect two embedding responses from litellm_response since input had two
print(openai_response_keys)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_embedding()
def test_openai_azure_embedding_simple():
try:
response = embedding(
@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple():
)
print(response)
response_keys = set(dict(response).keys())
response_keys.discard('_response_ms')
assert set(["usage", "model", "object", "data"]) == set(response_keys) #assert litellm response has expected keys from OpenAI embedding response
response_keys.discard("_response_ms")
assert set(["usage", "model", "object", "data"]) == set(
response_keys
) # assert litellm response has expected keys from OpenAI embedding response
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding_simple()
@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts():
response = embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
timeout=0.00001
timeout=0.00001,
)
print(response)
except openai.APITimeoutError:
print("Good job got timeout error!")
pass
except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_azure_embedding_timeouts()
def test_openai_embedding_timeouts():
try:
response = embedding(
model="text-embedding-ada-002",
input=["good morning from litellm"],
timeout=0.00001
timeout=0.00001,
)
print(response)
except openai.APITimeoutError:
print("Good job got OpenAI timeout error!")
pass
except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}")
pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_embedding_timeouts()
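A compact sketch of the timeout path both tests above exercise: an unrealistically small timeout should surface as openai.APITimeoutError, which these tests treat as the expected outcome.

import openai
from litellm import embedding

try:
    embedding(
        model="text-embedding-ada-002",
        input=["good morning from litellm"],
        timeout=0.00001,  # deliberately too small to ever succeed
    )
except openai.APITimeoutError:
    print("timed out as expected")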
def test_openai_azure_embedding():
try:
api_key = os.environ['AZURE_API_KEY']
api_base = os.environ['AZURE_API_BASE']
api_version = os.environ['AZURE_API_VERSION']
api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ["AZURE_API_VERSION"]
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
response = embedding(
model="azure/azure-embedding-model",
@ -114,137 +136,179 @@ def test_openai_azure_embedding():
)
print(response)
os.environ['AZURE_API_VERSION'] = api_version
os.environ['AZURE_API_BASE'] = api_base
os.environ['AZURE_API_KEY'] = api_key
os.environ["AZURE_API_VERSION"] = api_version
os.environ["AZURE_API_BASE"] = api_base
os.environ["AZURE_API_KEY"] = api_key
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding()
# test_openai_embedding()
def test_cohere_embedding():
try:
# litellm.set_verbose=True
response = embedding(
model="embed-english-v2.0", input=["good morning from litellm", "this is another item"]
model="embed-english-v2.0",
input=["good morning from litellm", "this is another item"],
)
print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding()
def test_cohere_embedding3():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="embed-english-v3.0",
input=["good morning from litellm", "this is another item"],
)
print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding3()
def test_bedrock_embedding_titan():
try:
litellm.set_verbose=True
litellm.set_verbose = True
response = embedding(
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
"lets test a second string for good measure"]
model="amazon.titan-embed-text-v1",
input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
)
print(f"response:", response)
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
assert isinstance(
response["data"][0]["embedding"], list
), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere():
try:
litellm.set_verbose=False
litellm.set_verbose = False
response = embedding(
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
aws_region_name="os.environ/AWS_REGION_NAME_2"
model="cohere.embed-multilingual-v3",
input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
aws_region_name="os.environ/AWS_REGION_NAME_2",
)
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list"
print(f"type of first embedding:", type(response['data'][0]['embedding'][0]))
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
assert isinstance(
response["data"][0]["embedding"], list
), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
# print(f"response:", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_cohere()
# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
try:
# huggingface/microsoft/codebert-base
# huggingface/facebook/bart-large
response = embedding(
model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
input=["good morning from litellm", "this is another item"],
)
print(f"response:", response)
except Exception as e:
# Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
pass
# test_hf_embedding()
# test async embeddings
def test_aembedding():
try:
import asyncio
async def embedding_call():
try:
response = await litellm.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"]
model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call())
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_aembedding()
def test_aembedding_azure():
try:
import asyncio
async def embedding_call():
try:
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm", "this is another item"]
model="azure/azure-embedding-model",
input=["good morning from litellm", "this is another item"],
)
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call())
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_aembedding_azure()
def test_sagemaker_embeddings():
try:
response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
def test_sagemaker_embeddings():
try:
response = litellm.embedding(
model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
input=["good morning from litellm", "this is another item"],
)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_sagemaker_embeddings()
# def local_proxy_embeddings():
# litellm.set_verbose=True
# response = embedding(
# model="openai/custom_embedding",
# input=["good morning from litellm"],
# api_base="http://0.0.0.0:8000/"
# )
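A short usage sketch consolidating what the embedding tests above assert: the response carries a data list with one embedding (a list of floats) per input string. Model and inputs are placeholders.

import litellm

response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["good morning from litellm", "this is another item"],
)
for item in response["data"]:
    vector = item["embedding"]  # a list of floats, per the assertions above
    print(len(vector), type(vector[0]))  # one vector per input string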


@ -11,17 +11,18 @@ import litellm
from litellm import (
embedding,
completion,
# AuthenticationError,
ContextWindowExceededError,
# RateLimitError,
# ServiceUnavailableError,
# OpenAIError,
)
from concurrent.futures import ThreadPoolExecutor
import pytest
litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
litellm.num_retries=0
litellm.num_retries = 0
# litellm.failure_callback = ["sentry"]
#### What this tests ####
@ -36,7 +37,8 @@ litellm.num_retries=0
models = ["command-nightly"]
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
print("Testing context window error")
@ -52,17 +54,27 @@ def test_context_window(model):
print(f"Worked!")
except RateLimitError:
print("RateLimited!")
except Exception as e:
print(f"{e}")
pytest.fail(f"An error occcurred - {e}")
@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"}
ctx_window_fallback_dict = {
"command-nightly": "claude-2",
"gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
"azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
}
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
completion(
model=model,
messages=messages,
context_window_fallback_dict=ctx_window_fallback_dict,
)
# for model in litellm.models_by_provider["bedrock"]:
# test_context_window(model=model)
@ -98,7 +110,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["AI21_API_KEY"] = "bad-key"
elif "togethercomputer" in model:
temporary_key = os.environ["TOGETHERAI_API_KEY"]
os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
os.environ[
"TOGETHERAI_API_KEY"
] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
elif model in litellm.openrouter_models:
temporary_key = os.environ["OPENROUTER_API_KEY"]
os.environ["OPENROUTER_API_KEY"] = "bad-key"
@ -115,9 +129,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
temporary_key = os.environ["REPLICATE_API_KEY"]
os.environ["REPLICATE_API_KEY"] = "bad-key"
print(f"model: {model}")
response = completion(
model=model, messages=messages
)
response = completion(model=model, messages=messages)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {str(e)}")
@ -148,23 +160,25 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["REPLICATE_API_KEY"] = temporary_key
elif "j2" in model:
os.environ["AI21_API_KEY"] = temporary_key
elif ("togethercomputer" in model):
elif "togethercomputer" in model:
os.environ["TOGETHERAI_API_KEY"] = temporary_key
elif model in litellm.aleph_alpha_models:
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
elif model in litellm.nlp_cloud_models:
os.environ["NLP_CLOUD_API_KEY"] = temporary_key
elif "bedrock" in model:
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return
# for model in litellm.models_by_provider["bedrock"]:
# invalid_auth(model=model)
# invalid_auth(model="command-nightly")
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):
messages = [{"content": "hey, how's it going?", "role": "user"}]
@ -173,23 +187,18 @@ def test_invalid_request_error(model):
completion(model=model, messages=messages, max_tokens="hello world")
def test_completion_azure_exception():
try:
import openai
print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning"
response = completion(
model="azure/chatgpt-v-2",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
os.environ["AZURE_API_KEY"] = old_azure_key
print(f"response: {response}")
@ -199,25 +208,24 @@ def test_completion_azure_exception():
print("good job got the correct error for azure when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_azure_exception()
async def asynctest_completion_azure_exception():
try:
import openai
import litellm
print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning"
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception():
print("Got wrong exception")
print("exception", e)
pytest.fail(f"Error occurred: {e}")
# import asyncio
# asyncio.run(
# asynctest_completion_azure_exception()
@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
async def test():
response = await litellm.acompletion(
model="openai/gpt-6",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model():
assert isinstance(e, openai.BadRequestError)
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_openai_exception_bad_model()
def asynctest_completion_azure_exception_bad_model():
try:
import openai
import litellm, asyncio
print("azure exception bad model\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
async def test():
response = await litellm.acompletion(
model="azure/gpt-12",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
asyncio.run(test())
except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!")
@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model():
print("Raised wrong type of exception", type(e))
pytest.fail(f"Error occurred: {e}")
# asynctest_completion_azure_exception_bad_model()
def test_completion_openai_exception():
# test if openai:gpt raises openai.AuthenticationError
try:
import openai
print("openai gpt-3.5 test\n\n")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "good morning"
response = completion(
model="gpt-4",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -317,25 +321,24 @@ def test_completion_openai_exception():
print("OpenAI: good job got the correct error for openai when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception()
def test_completion_mistral_exception():
# test if mistral/mistral-tiny raises openai.AuthenticationError
try:
import openai
print("Testing mistral ai exception mapping")
litellm.set_verbose=True
litellm.set_verbose = True
## Test azure call
old_azure_key = os.environ["MISTRAL_API_KEY"]
os.environ["MISTRAL_API_KEY"] = "good morning"
response = completion(
model="mistral/mistral-tiny",
messages=[
{
"role": "user",
"content": "hello"
}
],
messages=[{"role": "user", "content": "hello"}],
)
print(f"response: {response}")
print(response)
@ -344,11 +347,11 @@ def test_completion_mistral_exception():
print("good job got the correct error for openai when key not set")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception()
# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):
@ -387,4 +390,4 @@ def test_completion_mistral_exception():
# counts[result] += 1
# accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}")
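A brief sketch of the exception mapping these tests exercise, using only names imported in this file: an oversized prompt is expected to raise ContextWindowExceededError unless a context_window_fallback_dict is supplied to reroute the request to a larger model.

from litellm import completion, ContextWindowExceededError

sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]

try:
    completion(model="command-nightly", messages=messages)
except ContextWindowExceededError:
    # expected for a prompt this large; alternatively pass
    # context_window_fallback_dict={"command-nightly": "claude-2"} to reroute instead of raising
    print("context window exceeded, as the tests above anticipate")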

Some files were not shown because too many files have changed in this diff.