refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions
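The hunks below are mechanical black reformatting: single quotes become double quotes, long calls and literals are wrapped across lines, and comment spacing is normalized. As a minimal illustration (my own sketch, not taken from the diff, assuming the black package is installed), black's format_str helper applies the same normalization:

import black

src = "models = ['gpt-3.5-turbo', 'claude-2']\n"
print(black.format_str(src, mode=black.Mode()))
# prints: models = ["gpt-3.5-turbo", "claude-2"]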


@@ -1,9 +1,13 @@
repos:
  - repo: https://github.com/psf/black
    rev: stable
    hooks:
      - id: black
  - repo: https://github.com/pycqa/flake8
    rev: 3.8.4  # The version of flake8 to use
    hooks:
      - id: flake8
        exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/
        additional_dependencies: [flake8-print]
        files: litellm/.*\.py
  - repo: local
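A minimal sketch of how these hooks might be run locally (an assumed workflow, not part of this commit); it simply shells out to the pre-commit CLI, equivalent to running "pre-commit run --all-files" in the repo root:

import subprocess

# run every configured hook (black, flake8, local hooks) against all files
subprocess.run(["pre-commit", "run", "--all-files"], check=False)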


@@ -9,33 +9,37 @@ import os
# Define the list of models to benchmark
# select any LLM listed here: https://docs.litellm.ai/docs/providers
models = ["gpt-3.5-turbo", "claude-2"]

# Enter LLM API keys
# https://docs.litellm.ai/docs/providers
os.environ["OPENAI_API_KEY"] = ""
os.environ["ANTHROPIC_API_KEY"] = ""

# List of questions to benchmark (replace with your questions)
questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"]

# Enter your system prompt here
system_prompt = """
You are LiteLLMs helpful assistant
"""


@click.command()
@click.option(
    "--system-prompt",
    default="You are a helpful assistant that can answer questions.",
    help="System prompt for the conversation.",
)
def main(system_prompt):
    for question in questions:
        data = []  # Data for the current question

        with tqdm(total=len(models)) as pbar:
            for model in models:
                colored_description = colored(
                    f"Running question: {question} for model: {model}", "green"
                )
                pbar.set_description(colored_description)
                start_time = time.time()

@@ -44,35 +48,43 @@ def main(system_prompt):
                    max_tokens=500,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": question},
                    ],
                )
                end = time.time()
                total_time = end - start_time
                cost = completion_cost(completion_response=response)
                raw_response = response["choices"][0]["message"]["content"]

                data.append(
                    {
                        "Model": colored(model, "light_blue"),
                        "Response": raw_response,  # Colorize the response
                        "ResponseTime": colored(f"{total_time:.2f} seconds", "red"),
                        "Cost": colored(f"${cost:.6f}", "green"),  # Colorize the cost
                    }
                )

                pbar.update(1)

        # Separate headers from the data
        headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"]
        colwidths = [15, 80, 15, 10]

        # Create a nicely formatted table for the current question
        table = tabulate(
            [list(d.values()) for d in data],
            headers,
            tablefmt="grid",
            maxcolwidths=colwidths,
        )

        # Print the table for the current question
        colored_question = colored(question, "green")
        click.echo(f"\nBenchmark Results for '{colored_question}':")
        click.echo(table)  # Display the formatted table


if __name__ == "__main__":
    main()


@@ -1,25 +1,22 @@
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import litellm
from litellm import embedding, completion, completion_cost
from autoevals.llm import *

###################
import litellm

# litellm completion call
question = "which country has the highest population"
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": question}],
)
print(response)
# use the auto eval Factuality() evaluator

@@ -27,9 +24,11 @@ print(response)
print("calling evaluator")
evaluator = Factuality()
result = evaluator(
    output=response.choices[0]["message"][
        "content"
    ],  # response from litellm.completion()
    expected="India",  # expected output
    input=question,  # question passed to litellm.completion
)

print(result)


@@ -4,9 +4,10 @@ from flask_cors import CORS
import traceback
import litellm

from util import handle_error
from litellm import completion
import os, dotenv, time
import json

dotenv.load_dotenv()

# TODO: set your keys in .env or here:

@@ -19,57 +20,72 @@ verbose = True
# litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/

######### PROMPT LOGGING ##########
os.environ[
    "PROMPTLAYER_API_KEY"
] = ""  # set your promptlayer key here - https://promptlayer.com/

# set callbacks
litellm.success_callback = ["promptlayer"]

############ HELPER FUNCTIONS ###################################


def print_verbose(print_statement):
    if verbose:
        print(print_statement)


app = Flask(__name__)
CORS(app)


@app.route("/")
def index():
    return "received!", 200


def data_generator(response):
    for chunk in response:
        yield f"data: {json.dumps(chunk)}\n\n"


@app.route("/chat/completions", methods=["POST"])
def api_completion():
    data = request.json
    start_time = time.time()
    if data.get("stream") == "True":
        data["stream"] = True  # convert to boolean
    try:
        if "prompt" not in data:
            raise ValueError("data needs to have prompt")
        data[
            "model"
        ] = "togethercomputer/CodeLlama-34b-Instruct"  # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
        # COMPLETION CALL
        system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data.pop("prompt")},
        ]
        data["messages"] = messages
        print(f"data: {data}")
        response = completion(**data)
        ## LOG SUCCESS
        end_time = time.time()
        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            return Response(data_generator(response), mimetype="text/event-stream")
    except Exception as e:
        # call handle_error function
        print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
        ## LOG FAILURE
        end_time = time.time()
        traceback_exception = traceback.format_exc()
        return handle_error(data=data)
    return response


@app.route("/get_models", methods=["POST"])
def get_models():
    try:
        return litellm.model_list

@@ -78,7 +94,8 @@ def get_models():
        response = {"error": str(e)}
        return response, 200


if __name__ == "__main__":
    from waitress import serve

    serve(app, host="0.0.0.0", port=4000, threads=500)
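As a usage sketch for the reformatted proxy above (assumed, not part of the commit): the /chat/completions route expects a "prompt" field and the waitress server listens on port 4000, so a local test call could look like this, with the prompt text as a placeholder:

import requests

# hypothetical request against the Flask proxy defined above
resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    json={"prompt": "write a function that reverses a string"},
)
print(resp.status_code, resp.text)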


@@ -3,27 +3,28 @@ from urllib.parse import urlparse, parse_qs
def get_next_url(response):
    """
    Function to get 'next' url from Link header
    :param response: response from requests
    :return: next url or None
    """
    if "link" not in response.headers:
        return None
    headers = response.headers

    next_url = headers["Link"]
    print(next_url)
    start_index = next_url.find("<")
    end_index = next_url.find(">")

    return next_url[1:end_index]


def get_models(url):
    """
    Function to retrieve all models from paginated endpoint
    :param url: base url to make GET request
    :return: list of all models
    """
    models = []
    while url:

@@ -36,19 +37,21 @@ def get_models(url):
        models.extend(payload)
    return models


def get_cleaned_models(models):
    """
    Function to clean retrieved models
    :param models: list of retrieved models
    :return: list of cleaned models
    """
    cleaned_models = []
    for model in models:
        cleaned_models.append(model["id"])
    return cleaned_models


# Get text-generation models
url = "https://huggingface.co/api/models?filter=text-generation-inference"
text_generation_models = get_models(url)
cleaned_text_generation_models = get_cleaned_models(text_generation_models)

@@ -56,7 +59,7 @@ print(cleaned_text_generation_models)
# Get conversational models
url = "https://huggingface.co/api/models?filter=conversational"
conversational_models = get_models(url)
cleaned_conversational_models = get_cleaned_models(conversational_models)

@@ -65,19 +68,23 @@ print(cleaned_conversational_models)
def write_to_txt(cleaned_models, filename):
    """
    Function to write the contents of a list to a text file
    :param cleaned_models: list of cleaned models
    :param filename: name of the text file
    """
    with open(filename, "w") as f:
        for item in cleaned_models:
            f.write("%s\n" % item)


# Write contents of cleaned_text_generation_models to text_generation_models.txt
write_to_txt(
    cleaned_text_generation_models,
    "huggingface_llms_metadata/hf_text_generation_models.txt",
)

# Write contents of cleaned_conversational_models to conversational_models.txt
write_to_txt(
    cleaned_conversational_models,
    "huggingface_llms_metadata/hf_conversational_models.txt",
)


@@ -1,4 +1,3 @@
import openai

api_base = f"http://0.0.0.0:8000"

@@ -8,29 +7,29 @@ openai.api_key = "temp-key"
print(openai.api_base)
print(f"LiteLLM: response from proxy with streaming")

response = openai.ChatCompletion.create(
    model="ollama/llama2",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, acknowledge that you got it",
        }
    ],
    stream=True,
)

for chunk in response:
    print(f"LiteLLM: streaming response from proxy {chunk}")

response = openai.ChatCompletion.create(
    model="ollama/llama2",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, acknowledge that you got it",
        }
    ],
)

print(f"LiteLLM: response from proxy {response}")


@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm

litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")

model_list = [
    {  # list of model deployments
        "model_name": "gpt-3.5-turbo",  # model alias
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",  # actual model name
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-functioncalling",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(model_list=model_list)

file_paths = [
    "test_questions/question1.txt",
    "test_questions/question2.txt",
    "test_questions/question3.txt",
]
questions = []

for file_path in file_paths:
    try:
        print(file_path)
        with open(file_path, "r") as file:
            content = file.read()
            questions.append(content)
    except FileNotFoundError as e:

@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)

# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions

import concurrent.futures
import random

@@ -74,10 +82,18 @@ def make_openai_completion(question):
    try:
        start_time = time.time()
        import openai

        client = openai.OpenAI(
            api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
        )  # base_url="http://0.0.0.0:8000",
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": f"You are a helpful assistant. Answer this question{question}",
                }
            ],
        )
        print(response)
        end_time = time.time()

@@ -92,11 +108,10 @@ def make_openai_completion(question):
    except Exception as e:
        # Log exceptions for failed calls
        with open("error_log.txt", "a") as error_log_file:
            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
        return None


# Number of concurrent calls (you can adjust this)
concurrent_calls = 100

@@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
    print("\nError Log:\n", error_log_file.read())


@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm

litellm.set_verbose = False
# os.environ.pop("AZURE_AD_TOKEN")

model_list = [
    {  # list of model deployments
        "model_name": "gpt-3.5-turbo",  # model alias
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",  # actual model name
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-functioncalling",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(model_list=model_list)

file_paths = [
    "test_questions/question1.txt",
    "test_questions/question2.txt",
    "test_questions/question3.txt",
]
questions = []

for file_path in file_paths:
    try:
        print(file_path)
        with open(file_path, "r") as file:
            content = file.read()
            questions.append(content)
    except FileNotFoundError as e:

@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)

# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions

import concurrent.futures
import random

@@ -76,9 +84,12 @@ def make_openai_completion(question):
        import requests

        data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a helpful assistant. Answer this question{question}",
                },
            ],
        }
        response = requests.post("http://0.0.0.0:8000/queue/request", json=data)

@@ -89,8 +100,8 @@ def make_openai_completion(question):
            log_file.write(
                f"Question: {question[:100]}\nResponse ID: {response.get('id', 'N/A')} Url: {response.get('url', 'N/A')}\nTime: {end_time - start_time:.2f} seconds\n\n"
            )

        # polling the url
        while True:
            try:
                url = response["url"]

@@ -107,7 +118,9 @@ def make_openai_completion(question):
                    )
                    break

                print(
                    f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
                )
                time.sleep(0.5)
            except Exception as e:
                print("got exception in polling", e)

@@ -117,11 +130,10 @@ def make_openai_completion(question):
    except Exception as e:
        # Log exceptions for failed calls
        with open("error_log.txt", "a") as error_log_file:
            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
        return None


# Number of concurrent calls (you can adjust this)
concurrent_calls = 10

@@ -142,7 +154,7 @@ successful_calls = 0
failed_calls = 0

for future in futures:
    if future.done():
        if future.result() is not None:
            successful_calls += 1
        else:

@@ -152,4 +164,3 @@ print(f"Load test Summary:")
print(f"Total Requests: {concurrent_calls}")
print(f"Successful Calls: {successful_calls}")
print(f"Failed Calls: {failed_calls}")


@@ -12,42 +12,51 @@ import pytest
from litellm import Router
import litellm

litellm.set_verbose = False
os.environ.pop("AZURE_AD_TOKEN")

model_list = [
    {  # list of model deployments
        "model_name": "gpt-3.5-turbo",  # model alias
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",  # actual model name
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-functioncalling",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(model_list=model_list)

file_paths = [
    "test_questions/question1.txt",
    "test_questions/question2.txt",
    "test_questions/question3.txt",
]
questions = []

for file_path in file_paths:
    try:
        print(file_path)
        with open(file_path, "r") as file:
            content = file.read()
            questions.append(content)
    except FileNotFoundError as e:

@@ -59,10 +68,9 @@ for file_path in file_paths:
# print(q)

# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions

import concurrent.futures
import random

@@ -75,7 +83,12 @@ def make_openai_completion(question):
        start_time = time.time()
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": f"You are a helpful assistant. Answer this question{question}",
                }
            ],
        )
        print(response)
        end_time = time.time()

@@ -90,11 +103,10 @@ def make_openai_completion(question):
    except Exception as e:
        # Log exceptions for failed calls
        with open("error_log.txt", "a") as error_log_file:
            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
        return None


# Number of concurrent calls (you can adjust this)
concurrent_calls = 150

@@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file:
with open("error_log.txt", "r") as error_log_file:
    print("\nError Log:\n", error_log_file.read())


@@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_async_input_callback: List[
    Callable
] = []  # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
    Union[str, Callable]
] = []  # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
    Callable
] = []  # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
email: Optional[

@@ -42,20 +48,88 @@ aleph_alpha_key: Optional[str] = None
nlp_cloud_key: Optional[str] = None
use_client: bool = False
logging: bool = True
caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
    Cache
] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0  # set the max budget across all providers
_openai_completion_params = [
    "functions",
    "function_call",
    "temperature",
    "temperature",
    "top_p",
    "n",
    "stream",
    "stop",
    "max_tokens",
    "presence_penalty",
    "frequency_penalty",
    "logit_bias",
    "user",
    "request_timeout",
    "api_base",
    "api_version",
    "api_key",
    "deployment_id",
    "organization",
    "base_url",
    "default_headers",
    "timeout",
    "response_format",
    "seed",
    "tools",
    "tool_choice",
    "max_retries",
]
_litellm_completion_params = [
    "metadata",
    "acompletion",
    "caching",
    "mock_response",
    "api_key",
    "api_version",
    "api_base",
    "force_timeout",
    "logger_fn",
    "verbose",
    "custom_llm_provider",
    "litellm_logging_obj",
    "litellm_call_id",
    "use_client",
    "id",
    "fallbacks",
    "azure",
    "headers",
    "model_list",
    "num_retries",
    "context_window_fallback_dict",
    "roles",
    "final_prompt_value",
    "bos_token",
    "eos_token",
    "request_timeout",
    "complete_response",
    "self",
    "client",
    "rpm",
    "tpm",
    "input_cost_per_token",
    "output_cost_per_token",
    "hf_model_name",
    "model_info",
    "proxy_server_request",
    "preset_cache_key",
]
_current_cost = 0  # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None

@@ -66,23 +140,35 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
allowed_fails: int = 0

####### SECRET MANAGERS #####################
secret_manager_client: Optional[
    Any
] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.

#############################################


def get_model_cost_map(url: str):
    try:
        with requests.get(
            url, timeout=5
        ) as response:  # set a 5 second timeout for the get request
            response.raise_for_status()  # Raise an exception if the request is unsuccessful
            content = response.json()
            return content
    except Exception as e:
        import importlib.resources
        import json

        with importlib.resources.open_text(
            "litellm", "model_prices_and_context_window_backup.json"
        ) as f:
            content = json.load(f)
            return content


model_cost = get_model_cost_map(url=model_cost_map_url)
custom_prompt_dict: Dict[str, dict] = {}


####### THREAD-SPECIFIC DATA ###################
class MyLocal(threading.local):
    def __init__(self):

@@ -123,56 +209,51 @@ bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []

for key, value in model_cost.items():
    if value.get("litellm_provider") == "openai":
        open_ai_chat_completion_models.append(key)
    elif value.get("litellm_provider") == "text-completion-openai":
        open_ai_text_completion_models.append(key)
    elif value.get("litellm_provider") == "cohere":
        cohere_models.append(key)
    elif value.get("litellm_provider") == "anthropic":
        anthropic_models.append(key)
    elif value.get("litellm_provider") == "openrouter":
        openrouter_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-text-models":
        vertex_text_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-code-text-models":
        vertex_code_text_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-language-models":
        vertex_language_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-vision-models":
        vertex_vision_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-chat-models":
        vertex_chat_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
        vertex_code_chat_models.append(key)
    elif value.get("litellm_provider") == "ai21":
        ai21_models.append(key)
    elif value.get("litellm_provider") == "nlp_cloud":
        nlp_cloud_models.append(key)
    elif value.get("litellm_provider") == "aleph_alpha":
        aleph_alpha_models.append(key)
    elif value.get("litellm_provider") == "bedrock":
        bedrock_models.append(key)
    elif value.get("litellm_provider") == "deepinfra":
        deepinfra_models.append(key)
    elif value.get("litellm_provider") == "perplexity":
        perplexity_models.append(key)

# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
    "api.perplexity.ai",
    "api.endpoints.anyscale.com/v1",
    "api.deepinfra.com/v1/openai",
    "api.mistral.ai/v1",
]

# this is maintained for Exception Mapping
openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]


# well supported replicate llms

@@ -209,23 +290,18 @@ huggingface_models: List = [
together_ai_models: List = [
    # llama llms - chat
    "togethercomputer/llama-2-70b-chat",
    # llama llms - language / instruct
    "togethercomputer/llama-2-70b",
    "togethercomputer/LLaMA-2-7B-32K",
    "togethercomputer/Llama-2-7B-32K-Instruct",
    "togethercomputer/llama-2-7b",
    # falcon llms
    "togethercomputer/falcon-40b-instruct",
    "togethercomputer/falcon-7b-instruct",
    # alpaca
    "togethercomputer/alpaca-7b",
    # chat llms
    "HuggingFaceH4/starchat-alpha",
    # code llms
    "togethercomputer/CodeLlama-34b",
    "togethercomputer/CodeLlama-34b-Instruct",

@@ -234,29 +310,27 @@ together_ai_models: List = [
    "NumbersStation/nsql-llama-2-7B",
    "WizardLM/WizardCoder-15B-V1.0",
    "WizardLM/WizardCoder-Python-34B-V1.0",
    # language llms
    "NousResearch/Nous-Hermes-Llama2-13b",
    "Austism/chronos-hermes-13b",
    "upstage/SOLAR-0-70b-16bit",
    "WizardLM/WizardLM-70B-V1.0",
]  # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)

baseten_models: List = [
    "qvv0xeq",
    "q841o8w",
    "31dxrj3",
]  # FALCON 7B  # WizardLM  # Mosaic ML

petals_models = [
    "petals-team/StableBeluga2",
]

ollama_models = ["llama2"]

maritalk_models = ["maritalk"]

model_list = (
    open_ai_chat_completion_models

@@ -308,7 +382,7 @@ provider_list: List = [
    "anyscale",
    "mistral",
    "maritalk",
    "custom",  # custom apis
]

models_by_provider: dict = {

@@ -327,28 +401,28 @@ models_by_provider: dict = {
    "ollama": ollama_models,
    "deepinfra": deepinfra_models,
    "perplexity": perplexity_models,
    "maritalk": maritalk_models,
}

# mapping for those models which have larger equivalents
longer_context_model_fallback_dict: dict = {
    # openai chat completion models
    "gpt-3.5-turbo": "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
    "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
    "gpt-4": "gpt-4-32k",
    "gpt-4-0314": "gpt-4-32k-0314",
    "gpt-4-0613": "gpt-4-32k-0613",
    # anthropic
    "claude-instant-1": "claude-2",
    "claude-instant-1.2": "claude-2",
    # vertexai
    "chat-bison": "chat-bison-32k",
    "chat-bison@001": "chat-bison-32k",
    "codechat-bison": "codechat-bison-32k",
    "codechat-bison@001": "codechat-bison-32k",
    # openrouter
    "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
    "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
}

@@ -357,20 +431,23 @@ open_ai_embedding_models: List = ["text-embedding-ada-002"]
cohere_embedding_models: List = [
    "embed-english-v3.0",
    "embed-english-light-v3.0",
    "embed-multilingual-v3.0",
    "embed-english-v2.0",
    "embed-english-light-v2.0",
    "embed-multilingual-v2.0",
]
bedrock_embedding_models: List = [
    "amazon.titan-embed-text-v1",
    "cohere.embed-english-v3",
    "cohere.embed-multilingual-v3",
]

all_embedding_models = (
    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
)

####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = ["dall-e-2", "dall-e-3"]

from .timeout import timeout

@@ -394,11 +471,11 @@ from .utils import (
    get_llm_provider,
    completion_with_config,
    register_model,
    encode,
    decode,
    _calculate_retry_after,
    _should_retry,
    get_secret,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

@@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock import (
    AmazonTitanConfig,
    AmazonAI21Config,
    AmazonAnthropicConfig,
    AmazonCohereConfig,
    AmazonLlamaConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig
from .main import *  # type: ignore

@@ -429,13 +512,13 @@ from .exceptions import (
    ServiceUnavailableError,
    OpenAIError,
    ContextWindowExceededError,
    BudgetExceededError,
    APIError,
    Timeout,
    APIConnectionError,
    APIResponseValidationError,
    UnprocessableEntityError,
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
from .router import Router
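A short usage sketch (assumed, not part of this commit) of the module-level lists assembled above, which are importable from the litellm package:

import litellm

# aggregated in the package __init__ shown above
print("gpt-3.5-turbo" in litellm.model_list)
print(litellm.openai_compatible_providers)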


@@ -1,8 +1,9 @@
set_verbose = False


def print_verbose(print_statement):
    try:
        if set_verbose:
            print(print_statement)  # noqa
    except:
        pass


@@ -13,6 +13,7 @@ import inspect
import redis, litellm
from typing import List, Optional


def _get_redis_kwargs():
    arg_spec = inspect.getfullargspec(redis.Redis)

@@ -23,32 +24,26 @@ def _get_redis_kwargs():
        "retry",
    }

    include_args = ["url"]

    available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args

    return available_args


def _get_redis_env_kwarg_mapping():
    PREFIX = "REDIS_"

    return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}


def _redis_kwargs_from_environment():
    mapping = _get_redis_env_kwarg_mapping()

    return_dict = {}
    for k, v in mapping.items():
        value = litellm.get_secret(k, default_value=None)  # check os.environ/key vault
        if value is not None:
            return_dict[v] = value
    return return_dict

@@ -56,21 +51,26 @@ def _redis_kwargs_from_environment():
def get_redis_url_from_environment():
    if "REDIS_URL" in os.environ:
        return os.environ["REDIS_URL"]

    if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
        raise ValueError(
            "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
        )

    if "REDIS_PASSWORD" in os.environ:
        redis_password = f":{os.environ['REDIS_PASSWORD']}@"
    else:
        redis_password = ""

    return (
        f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
    )


def get_redis_client(**env_overrides):
    ### check if "os.environ/<key-name>" passed in
    for k, v in env_overrides.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            v = v.replace("os.environ/", "")
            value = litellm.get_secret(v)
            env_overrides[k] = value

@@ -80,14 +80,14 @@ def get_redis_client(**env_overrides):
        **env_overrides,
    }

    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        redis_kwargs.pop("host", None)
        redis_kwargs.pop("port", None)
        redis_kwargs.pop("db", None)
        redis_kwargs.pop("password", None)

        return redis.Redis.from_url(**redis_kwargs)
    elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")
    litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
    return redis.Redis(**redis_kwargs)


@ -1,119 +1,166 @@
import os, json, time import os, json, time
import litellm import litellm
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
import requests, threading import requests, threading
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
class BudgetManager: class BudgetManager:
def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None): def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
):
self.client_type = client_type self.client_type = client_type
self.project_name = project_name self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai" self.api_base = api_base or "https://api.litellm.ai"
## load the data or init the initial dictionaries ## load the data or init the initial dictionaries
self.load_data() self.load_data()
def print_verbose(self, print_statement): def print_verbose(self, print_statement):
try: try:
if litellm.set_verbose: if litellm.set_verbose:
import logging import logging
logging.info(print_statement) logging.info(print_statement)
except: except:
pass pass
def load_data(self): def load_data(self):
if self.client_type == "local": if self.client_type == "local":
# Check if user dict file exists # Check if user dict file exists
if os.path.isfile("user_cost.json"): if os.path.isfile("user_cost.json"):
# Load the user dict # Load the user dict
with open("user_cost.json", 'r') as json_file: with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file) self.user_dict = json.load(json_file)
else: else:
self.print_verbose("User Dictionary not found!") self.print_verbose("User Dictionary not found!")
self.user_dict = {} self.user_dict = {}
self.print_verbose(f"user dict from local: {self.user_dict}") self.print_verbose(f"user dict from local: {self.user_dict}")
elif self.client_type == "hosted": elif self.client_type == "hosted":
# Load the user_dict from hosted db # Load the user_dict from hosted db
url = self.api_base + "/get_budget" url = self.api_base + "/get_budget"
headers = {'Content-Type': 'application/json'} headers = {"Content-Type": "application/json"}
data = { data = {"project_name": self.project_name}
'project_name' : self.project_name
}
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
response = response.json() response = response.json()
if response["status"] == "error": if response["status"] == "error":
self.user_dict = {} # assume this means the user dict hasn't been stored yet self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else: else:
self.user_dict = response["data"] self.user_dict = response["data"]
def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float = time.time()): def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
created_at: float = time.time(),
):
self.user_dict[user] = {"total_budget": total_budget} self.user_dict[user] = {"total_budget": total_budget}
if duration is None: if duration is None:
return self.user_dict[user] return self.user_dict[user]
if duration == 'daily': if duration == "daily":
duration_in_days = 1 duration_in_days = 1
elif duration == 'weekly': elif duration == "weekly":
duration_in_days = 7 duration_in_days = 7
elif duration == 'monthly': elif duration == "monthly":
duration_in_days = 28 duration_in_days = 28
elif duration == 'yearly': elif duration == "yearly":
duration_in_days = 365 duration_in_days = 365
else: else:
raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""") raise ValueError(
self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at} """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution )
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user] return self.user_dict[user]
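As a worked example of the duration mapping above (BudgetManager is assumed to be exported at the package top level, as in the docs; timestamps are illustrative):

from litellm import BudgetManager  # assumed top-level export

budget_manager = BudgetManager(project_name="demo_project")  # local client by default
budget_manager.create_budget(total_budget=10.0, user="user@example.com", duration="monthly")
# The stored entry now looks like:
# {"total_budget": 10.0, "duration": 28,            # "monthly" -> 28 days
#  "created_at": 1703500000.0, "last_updated_at": 1703500000.0}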
def projected_cost(self, model: str, messages: list, user: str): def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages) text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text) prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0) prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0) current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost projected_cost = prompt_cost + current_cost
return projected_cost return projected_cost
def get_total_budget(self, user: str): def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"] return self.user_dict[user]["total_budget"]
def update_cost(self, user: str, completion_obj: Optional[ModelResponse] = None, model: Optional[str] = None, input_text: Optional[str] = None, output_text: Optional[str] = None):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": input_text}])
completion_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": output_text}])
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj['model'] # if this throws an error try, model = completion_obj['model']
else:
raise ValueError("Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager")
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0) def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": input_text}]
)
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj[
"model"
] # if this throws an error try, model = completion_obj['model']
else:
raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]: if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user]["model_cost"].get(model, 0) self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else: else:
self.user_dict[user]["model_cost"] = {model: cost} self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]} return {"user": self.user_dict[user]}
def get_current_cost(self, user): def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0) return self.user_dict[user].get("current_cost", 0)
def get_model_cost(self, user): def get_model_cost(self, user):
return self.user_dict[user].get("model_cost", 0) return self.user_dict[user].get("model_cost", 0)
def is_valid_user(self, user: str) -> bool: def is_valid_user(self, user: str) -> bool:
return user in self.user_dict return user in self.user_dict
def get_users(self): def get_users(self):
return list(self.user_dict.keys()) return list(self.user_dict.keys())
def reset_cost(self, user): def reset_cost(self, user):
self.user_dict[user]["current_cost"] = 0 self.user_dict[user]["current_cost"] = 0
self.user_dict[user]["model_cost"] = {} self.user_dict[user]["model_cost"] = {}
return {"user": self.user_dict[user]} return {"user": self.user_dict[user]}
def reset_on_duration(self, user: str): def reset_on_duration(self, user: str):
# Get current and creation time # Get current and creation time
last_updated_at = self.user_dict[user]["last_updated_at"] last_updated_at = self.user_dict[user]["last_updated_at"]
@@ -121,38 +168,39 @@ class BudgetManager:
# Convert duration from days to seconds # Convert duration from days to seconds
duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60 duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60
# Check if duration has elapsed # Check if duration has elapsed
if current_time - last_updated_at >= duration_in_seconds: if current_time - last_updated_at >= duration_in_seconds:
# Reset cost if duration has elapsed and update the creation time # Reset cost if duration has elapsed and update the creation time
self.reset_cost(user) self.reset_cost(user)
self.user_dict[user]["last_updated_at"] = current_time self.user_dict[user]["last_updated_at"] = current_time
self._save_data_thread() # Save the data self._save_data_thread() # Save the data
def update_budget_all_users(self): def update_budget_all_users(self):
for user in self.get_users(): for user in self.get_users():
if "duration" in self.user_dict[user]: if "duration" in self.user_dict[user]:
self.reset_on_duration(user) self.reset_on_duration(user)
def _save_data_thread(self): def _save_data_thread(self):
thread = threading.Thread(target=self.save_data) # [Non-Blocking]: saves data without blocking execution thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start() thread.start()
def save_data(self): def save_data(self):
if self.client_type == "local": if self.client_type == "local":
import json import json
# save the user dict # save the user dict
with open("user_cost.json", 'w') as json_file: with open("user_cost.json", "w") as json_file:
json.dump(self.user_dict, json_file, indent=4) # Indent for pretty formatting json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"} return {"status": "success"}
elif self.client_type == "hosted": elif self.client_type == "hosted":
url = self.api_base + "/set_budget" url = self.api_base + "/set_budget"
headers = {'Content-Type': 'application/json'} headers = {"Content-Type": "application/json"}
data = { data = {"project_name": self.project_name, "user_dict": self.user_dict}
'project_name' : self.project_name,
"user_dict": self.user_dict
}
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
response = response.json() response = response.json()
return response return response
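Putting the class above together, a hedged end-to-end sketch (local client; the model name and API key are placeholders, and BudgetExceededError is assumed to be re-exported at the package top level):

import os
import litellm
from litellm import BudgetManager, completion

os.environ["OPENAI_API_KEY"] = ""  # placeholder

budget_manager = BudgetManager(project_name="demo_project")
user = "user@example.com"
budget_manager.create_budget(total_budget=0.50, user=user, duration="daily")

if budget_manager.get_current_cost(user=user) <= budget_manager.get_total_budget(user):
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    # Record spend for this user; save_data() runs in a background thread.
    budget_manager.update_cost(completion_obj=response, user=user)
else:
    raise litellm.BudgetExceededError(
        current_cost=budget_manager.get_current_cost(user=user),
        max_budget=budget_manager.get_total_budget(user),
    )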
@@ -12,13 +12,15 @@ import time, logging
import json, traceback, ast import json, traceback, ast
from typing import Optional, Literal, List from typing import Optional, Literal, List
def print_verbose(print_statement): def print_verbose(print_statement):
try: try:
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
except: except:
pass pass
class BaseCache: class BaseCache:
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
raise NotImplementedError raise NotImplementedError
@@ -45,13 +47,13 @@ class InMemoryCache(BaseCache):
self.cache_dict.pop(key, None) self.cache_dict.pop(key, None)
return None return None
original_cached_response = self.cache_dict[key] original_cached_response = self.cache_dict[key]
try: try:
cached_response = json.loads(original_cached_response) cached_response = json.loads(original_cached_response)
except: except:
cached_response = original_cached_response cached_response = original_cached_response
return cached_response return cached_response
return None return None
def flush_cache(self): def flush_cache(self):
self.cache_dict.clear() self.cache_dict.clear()
self.ttl_dict.clear() self.ttl_dict.clear()
@@ -60,17 +62,18 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache): class RedisCache(BaseCache):
def __init__(self, host=None, port=None, password=None, **kwargs): def __init__(self, host=None, port=None, password=None, **kwargs):
import redis import redis
# if users don't provide one, use the default litellm cache # if users don't provide one, use the default litellm cache
from ._redis import get_redis_client from ._redis import get_redis_client
redis_kwargs = {} redis_kwargs = {}
if host is not None: if host is not None:
redis_kwargs["host"] = host redis_kwargs["host"] = host
if port is not None: if port is not None:
redis_kwargs["port"] = port redis_kwargs["port"] = port
if password is not None: if password is not None:
redis_kwargs["password"] = password redis_kwargs["password"] = password
redis_kwargs.update(kwargs) redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs) self.redis_client = get_redis_client(**redis_kwargs)
@@ -88,13 +91,19 @@ class RedisCache(BaseCache):
try: try:
print_verbose(f"Get Redis Cache: key: {key}") print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key) cached_response = self.redis_client.get(key)
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}") print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
if cached_response != None: if cached_response != None:
# cached_response is in `b{}`, convert it to ModelResponse # cached_response is in `b{}`, convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string cached_response = cached_response.decode(
try: "utf-8"
cached_response = json.loads(cached_response) # Convert string to dictionary ) # Convert bytes to string
except: try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
return cached_response return cached_response
except Exception as e: except Exception as e:
@@ -105,34 +114,40 @@ class RedisCache(BaseCache):
def flush_cache(self): def flush_cache(self):
self.redis_client.flushall() self.redis_client.flushall()
class DualCache(BaseCache):
class DualCache(BaseCache):
""" """
This updates both Redis and an in-memory cache simultaneously. This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis. When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data. This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
""" """
def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None:
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
) -> None:
super().__init__() super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache # If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache() self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache # If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache self.redis_cache = redis_cache
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
# Update both Redis and in-memory cache # Update both Redis and in-memory cache
try: try:
print_verbose(f"set cache: key: {key}; value: {value}") print_verbose(f"set cache: key: {key}; value: {value}")
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
self.in_memory_cache.set_cache(key, value, **kwargs) self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None: if self.redis_cache is not None:
self.redis_cache.set_cache(key, value, **kwargs) self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e: except Exception as e:
print_verbose(e) print_verbose(e)
def get_cache(self, key, **kwargs): def get_cache(self, key, **kwargs):
# Try to fetch from in-memory cache first # Try to fetch from in-memory cache first
try: try:
print_verbose(f"get cache: cache key: {key}") print_verbose(f"get cache: cache key: {key}")
result = None result = None
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
@@ -141,7 +156,7 @@ class DualCache(BaseCache):
if in_memory_result is not None: if in_memory_result is not None:
result = in_memory_result result = in_memory_result
if self.redis_cache is not None: if self.redis_cache is not None:
# If not found in in-memory cache, try fetching from Redis # If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs) redis_result = self.redis_cache.get_cache(key, **kwargs)
@@ -153,25 +168,28 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}") print_verbose(f"get cache: cache result: {result}")
return result return result
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
def flush_cache(self): def flush_cache(self):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache() self.in_memory_cache.flush_cache()
if self.redis_cache is not None: if self.redis_cache is not None:
self.redis_cache.flush_cache() self.redis_cache.flush_cache()
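A minimal sketch of wiring the two layers together (class paths assume litellm.caching; the Redis connection details are placeholders and a reachable Redis instance is assumed):

from litellm.caching import DualCache, InMemoryCache, RedisCache

dual_cache = DualCache(
    in_memory_cache=InMemoryCache(),
    redis_cache=RedisCache(host="localhost", port=6379, password=None),
)
dual_cache.set_cache("router:deployment-1", "healthy")  # written to both layers
print(dual_cache.get_cache("router:deployment-1"))      # served from memory first, Redis as fallback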
#### LiteLLM.Completion / Embedding Cache #### #### LiteLLM.Completion / Embedding Cache ####
class Cache: class Cache:
def __init__( def __init__(
self, self,
type: Optional[Literal["local", "redis"]] = "local", type: Optional[Literal["local", "redis"]] = "local",
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"], supported_call_types: Optional[
**kwargs List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs,
): ):
""" """
Initializes the cache based on the given type. Initializes the cache based on the given type.
@@ -200,7 +218,7 @@ class Cache:
litellm.success_callback.append("cache") litellm.success_callback.append("cache")
if "cache" not in litellm._async_success_callback: if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache") litellm._async_success_callback.append("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
def get_cache_key(self, *args, **kwargs): def get_cache_key(self, *args, **kwargs):
""" """
@@ -215,18 +233,37 @@ class Cache:
""" """
cache_key = "" cache_key = ""
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}") print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None: if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
print_verbose(f"\nReturning preset cache key: {cache_key}") print_verbose(f"\nReturning preset cache key: {cache_key}")
return kwargs.get("litellm_params", {}).get("preset_cache_key", None) return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4] # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"] completion_kwargs = [
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs "model",
"messages",
"temperature",
"top_p",
"n",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
]
embedding_only_kwargs = [
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set() # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs combined_kwargs = completion_kwargs + embedding_only_kwargs
for param in combined_kwargs: for param in combined_kwargs:
# ignore litellm params here # ignore litellm params here
if param in kwargs: if param in kwargs:
@@ -241,8 +278,8 @@ class Cache:
model_group = metadata.get("model_group", None) model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None) caching_groups = metadata.get("caching_groups", None)
if caching_groups: if caching_groups:
for group in caching_groups: for group in caching_groups:
if model_group in group: if model_group in group:
caching_group = group caching_group = group
break break
if litellm_params is not None: if litellm_params is not None:
@@ -251,23 +288,34 @@ class Cache:
model_group = metadata.get("model_group", None) model_group = metadata.get("model_group", None)
caching_groups = metadata.get("caching_groups", None) caching_groups = metadata.get("caching_groups", None)
if caching_groups: if caching_groups:
for group in caching_groups: for group in caching_groups:
if model_group in group: if model_group in group:
caching_group = group caching_group = group
break break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"] param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else: else:
if kwargs[param] is None: if kwargs[param] is None:
continue # ignore None params continue # ignore None params
param_value = kwargs[param] param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}" cache_key += f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}") print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key return cache_key
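A simplified sketch of the key layout this produces (ignoring the model_group/caching_group handling above); because the loop walks combined_kwargs in a fixed order, the caller's kwarg ordering does not change the key:

call_kwargs = {"temperature": 0.2, "model": "gpt-4", "max_tokens": 200}
ordered = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens"]  # prefix of combined_kwargs
cache_key = "".join(f"{p}: {call_kwargs[p]}" for p in ordered if p in call_kwargs)
print(cache_key)  # -> "model: gpt-4temperature: 0.2max_tokens: 200"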
def generate_streaming_content(self, content): def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size): for i in range(0, len(content), chunk_size):
yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]} yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(0.02) time.sleep(0.02)
def get_cache(self, *args, **kwargs): def get_cache(self, *args, **kwargs):
@@ -319,4 +367,4 @@ class Cache:
pass pass
async def _async_add_cache(self, result, *args, **kwargs): async def _async_add_cache(self, result, *args, **kwargs):
self.add_cache(result, *args, **kwargs) self.add_cache(result, *args, **kwargs)
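A hedged sketch of switching this cache on for completion calls (the litellm.cache assignment follows the library's documented pattern; the model is a placeholder and a valid API key is assumed):

import litellm
from litellm.caching import Cache

# In-process cache; use type="redis" with host/port/password for a shared cache.
litellm.cache = Cache(type="local", supported_call_types=["completion", "acompletion"])

messages = [{"role": "user", "content": "What is 2 + 2?"}]
response_1 = litellm.completion(model="gpt-3.5-turbo", messages=messages)
# Identical kwargs build the same cache key, so the second call can be answered from the cache.
response_2 = litellm.completion(model="gpt-3.5-turbo", messages=messages)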
@@ -1,2 +1,2 @@
# from .main import * # from .main import *
# from .server_utils import * # from .server_utils import *
@@ -33,7 +33,7 @@
# llm_model_list: Optional[list] = None # llm_model_list: Optional[list] = None
# server_settings: Optional[dict] = None # server_settings: Optional[dict] = None
# set_callbacks() # sets litellm callbacks for logging if they exist in the environment # set_callbacks() # sets litellm callbacks for logging if they exist in the environment
# if "CONFIG_FILE_PATH" in os.environ: # if "CONFIG_FILE_PATH" in os.environ:
# llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH")) # llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
@@ -44,7 +44,7 @@
# @router.get("/models") # if project requires model list # @router.get("/models") # if project requires model list
# def model_list(): # def model_list():
# all_models = litellm.utils.get_valid_models() # all_models = litellm.utils.get_valid_models()
# if llm_model_list: # if llm_model_list:
# all_models += llm_model_list # all_models += llm_model_list
# return dict( # return dict(
# data=[ # data=[
@@ -79,8 +79,8 @@
# @router.post("/v1/embeddings") # @router.post("/v1/embeddings")
# @router.post("/embeddings") # @router.post("/embeddings")
# async def embedding(request: Request): # async def embedding(request: Request):
# try: # try:
# data = await request.json() # data = await request.json()
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
# if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header # if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers: # if users pass LLM api keys as part of header
# api_key = request.headers.get("authorization") # api_key = request.headers.get("authorization")
@@ -106,13 +106,13 @@
# data = await request.json() # data = await request.json()
# server_model = server_settings.get("completion_model", None) if server_settings else None # server_model = server_settings.get("completion_model", None) if server_settings else None
# data["model"] = server_model or model or data["model"] # data["model"] = server_model or model or data["model"]
# ## CHECK KEYS ## # ## CHECK KEYS ##
# # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers # # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
# # env_validation = litellm.validate_environment(model=data["model"]) # # env_validation = litellm.validate_environment(model=data["model"])
# # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header # # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
# # if "authorization" in request.headers: # # if "authorization" in request.headers:
# # api_key = request.headers.get("authorization") # # api_key = request.headers.get("authorization")
# # elif "api-key" in request.headers: # # elif "api-key" in request.headers:
# # api_key = request.headers.get("api-key") # # api_key = request.headers.get("api-key")
# # print(f"api_key in headers: {api_key}") # # print(f"api_key in headers: {api_key}")
# # if " " in api_key: # # if " " in api_key:
@@ -122,11 +122,11 @@
# # api_key = api_key # # api_key = api_key
# # data["api_key"] = api_key # # data["api_key"] = api_key
# # print(f"api_key in data: {api_key}") # # print(f"api_key in data: {api_key}")
# ## CHECK CONFIG ## # ## CHECK CONFIG ##
# if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]: # if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
# for m in llm_model_list: # for m in llm_model_list:
# if data["model"] == m["model_name"]: # if data["model"] == m["model_name"]:
# for key, value in m["litellm_params"].items(): # for key, value in m["litellm_params"].items():
# data[key] = value # data[key] = value
# break # break
# response = litellm.completion( # response = litellm.completion(
@@ -145,21 +145,21 @@
# @router.post("/router/completions") # @router.post("/router/completions")
# async def router_completion(request: Request): # async def router_completion(request: Request):
# global llm_router # global llm_router
# try: # try:
# data = await request.json() # data = await request.json()
# if "model_list" in data: # if "model_list" in data:
# llm_router = litellm.Router(model_list=data.pop("model_list")) # llm_router = litellm.Router(model_list=data.pop("model_list"))
# if llm_router is None: # if llm_router is None:
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body") # raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
# # openai.ChatCompletion.create replacement # # openai.ChatCompletion.create replacement
# response = await llm_router.acompletion(model="gpt-3.5-turbo", # response = await llm_router.acompletion(model="gpt-3.5-turbo",
# messages=[{"role": "user", "content": "Hey, how's it going?"}]) # messages=[{"role": "user", "content": "Hey, how's it going?"}])
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
# return StreamingResponse(data_generator(response), media_type='text/event-stream') # return StreamingResponse(data_generator(response), media_type='text/event-stream')
# return response # return response
# except Exception as e: # except Exception as e:
# error_traceback = traceback.format_exc() # error_traceback = traceback.format_exc()
# error_msg = f"{str(e)}\n\n{error_traceback}" # error_msg = f"{str(e)}\n\n{error_traceback}"
# return {"error": error_msg} # return {"error": error_msg}
@@ -167,11 +167,11 @@
# @router.post("/router/embedding") # @router.post("/router/embedding")
# async def router_embedding(request: Request): # async def router_embedding(request: Request):
# global llm_router # global llm_router
# try: # try:
# data = await request.json() # data = await request.json()
# if "model_list" in data: # if "model_list" in data:
# llm_router = litellm.Router(model_list=data.pop("model_list")) # llm_router = litellm.Router(model_list=data.pop("model_list"))
# if llm_router is None: # if llm_router is None:
# raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body") # raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
# response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore # response = await llm_router.aembedding(model="gpt-3.5-turbo", # type: ignore
@@ -180,7 +180,7 @@
# if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses # if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
# return StreamingResponse(data_generator(response), media_type='text/event-stream') # return StreamingResponse(data_generator(response), media_type='text/event-stream')
# return response # return response
# except Exception as e: # except Exception as e:
# error_traceback = traceback.format_exc() # error_traceback = traceback.format_exc()
# error_msg = f"{str(e)}\n\n{error_traceback}" # error_msg = f"{str(e)}\n\n{error_traceback}"
# return {"error": error_msg} # return {"error": error_msg}
@@ -190,4 +190,4 @@
# return "LiteLLM: RUNNING" # return "LiteLLM: RUNNING"
# app.include_router(router) # app.include_router(router)
@@ -3,7 +3,7 @@
# import dotenv # import dotenv
# dotenv.load_dotenv() # load env variables # dotenv.load_dotenv() # load env variables
# def print_verbose(print_statement): # def print_verbose(print_statement):
# pass # pass
# def get_package_version(package_name): # def get_package_version(package_name):
@@ -27,32 +27,31 @@
# def set_callbacks(): # def set_callbacks():
# ## LOGGING # ## LOGGING
# if len(os.getenv("SET_VERBOSE", "")) > 0: # if len(os.getenv("SET_VERBOSE", "")) > 0:
# if os.getenv("SET_VERBOSE") == "True": # if os.getenv("SET_VERBOSE") == "True":
# litellm.set_verbose = True # litellm.set_verbose = True
# print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m") # print_verbose("\033[92mLiteLLM: Switched on verbose logging\033[0m")
# else: # else:
# litellm.set_verbose = False # litellm.set_verbose = False
# ### LANGFUSE # ### LANGFUSE
# if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0: # if (len(os.getenv("LANGFUSE_PUBLIC_KEY", "")) > 0 and len(os.getenv("LANGFUSE_SECRET_KEY", ""))) > 0 or len(os.getenv("LANGFUSE_HOST", "")) > 0:
# litellm.success_callback = ["langfuse"] # litellm.success_callback = ["langfuse"]
# print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m") # print_verbose("\033[92mLiteLLM: Switched on Langfuse feature\033[0m")
# ## CACHING # ## CACHING
# ### REDIS # ### REDIS
# # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0: # # if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
# # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}") # # print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
# # from litellm.caching import Cache # # from litellm.caching import Cache
# # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) # # litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
# # print("\033[92mLiteLLM: Switched on Redis caching\033[0m") # # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
# def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'): # def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
# config = {} # config = {}
# server_settings = {} # server_settings = {}
# try: # try:
# if os.path.exists(config_file_path): # type: ignore # if os.path.exists(config_file_path): # type: ignore
# with open(config_file_path, 'r') as file: # type: ignore # with open(config_file_path, 'r') as file: # type: ignore
# config = yaml.safe_load(file) # config = yaml.safe_load(file)
@@ -63,24 +62,24 @@
# ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral') # ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
# server_settings = config.get("server_settings", None) # server_settings = config.get("server_settings", None)
# if server_settings: # if server_settings:
# server_settings = server_settings # server_settings = server_settings
# ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) # ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
# litellm_settings = config.get('litellm_settings', None) # litellm_settings = config.get('litellm_settings', None)
# if litellm_settings: # if litellm_settings:
# for key, value in litellm_settings.items(): # for key, value in litellm_settings.items():
# setattr(litellm, key, value) # setattr(litellm, key, value)
# ## MODEL LIST # ## MODEL LIST
# model_list = config.get('model_list', None) # model_list = config.get('model_list', None)
# if model_list: # if model_list:
# router = litellm.Router(model_list=model_list) # router = litellm.Router(model_list=model_list)
# ## ENVIRONMENT VARIABLES # ## ENVIRONMENT VARIABLES
# environment_variables = config.get('environment_variables', None) # environment_variables = config.get('environment_variables', None)
# if environment_variables: # if environment_variables:
# for key, value in environment_variables.items(): # for key, value in environment_variables.items():
# os.environ[key] = value # os.environ[key] = value
# return router, model_list, server_settings # return router, model_list, server_settings
@@ -16,11 +16,11 @@ from openai import (
RateLimitError, RateLimitError,
APIStatusError, APIStatusError,
OpenAIError, OpenAIError,
APIError, APIError,
APITimeoutError, APITimeoutError,
APIConnectionError, APIConnectionError,
APIResponseValidationError, APIResponseValidationError,
UnprocessableEntityError UnprocessableEntityError,
) )
import httpx import httpx
@@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8 # raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore class NotFoundError(NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
@@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
@@ -58,23 +55,21 @@ class BadRequestError(BadRequestError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422 self.status_code = 422
self.message = message self.message = message
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408 self.status_code = 408
@@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore
request=request request=request
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429 self.status_code = 429
@@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors # sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
@@ -106,12 +101,13 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore llm_provider=self.llm_provider, # type: ignore
response=response response=response,
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503 self.status_code = 503
@@ -119,50 +115,42 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore class APIError(APIError): # type: ignore
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request): def __init__(
self.status_code = status_code self, status_code, message, llm_provider, model, request: httpx.Request
):
self.status_code = status_code
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(self.message, request=request, body=None) # type: ignore
self.message,
request=request, # type: ignore
body=None
)
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore class APIConnectionError(APIConnectionError): # type: ignore
def __init__(self, message, llm_provider, model, request: httpx.Request): def __init__(self, message, llm_provider, model, request: httpx.Request):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.status_code = 500 self.status_code = 500
super().__init__( super().__init__(message=self.message, request=request)
message=self.message,
request=request
)
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore class APIResponseValidationError(APIResponseValidationError): # type: ignore
def __init__(self, message, llm_provider, model): def __init__(self, message, llm_provider, model):
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request) response = httpx.Response(status_code=500, request=request)
super().__init__( super().__init__(response=response, body=None, message=message)
response=response,
body=None,
message=message
)
class OpenAIError(OpenAIError): # type: ignore class OpenAIError(OpenAIError): # type: ignore
def __init__(self, original_exception): def __init__(self, original_exception):
@@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore
) )
self.llm_provider = "openai" self.llm_provider = "openai"
class BudgetExceededError(Exception): class BudgetExceededError(Exception):
def __init__(self, current_cost, max_budget): def __init__(self, current_cost, max_budget):
self.current_cost = current_cost self.current_cost = current_cost
@@ -183,7 +172,8 @@ class BudgetExceededError(Exception):
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}" message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
super().__init__(message) super().__init__(message)
## DEPRECATED ##
## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore class InvalidRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 400 self.status_code = 400
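Because provider errors are re-raised as these OpenAI-compatible classes, callers can branch on them uniformly; a hedged sketch (the model and API key are placeholders, and the import path assumes the module is litellm.exceptions):

from litellm import completion
from litellm.exceptions import APIConnectionError, ContextWindowExceededError, RateLimitError

try:
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello!"}],
    )
except ContextWindowExceededError:
    pass  # prompt is too long for the model's context window
except RateLimitError:
    pass  # back off and retry later
except APIConnectionError:
    pass  # network-level failure reaching the provider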
@@ -5,32 +5,33 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal from typing import Literal
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
pass pass
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
pass pass
def log_stream_event(self, kwargs, response_obj, start_time, end_time): def log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass pass
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
pass pass
def log_failure_event(self, kwargs, response_obj, start_time, end_time): def log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass pass
#### ASYNC #### #### ASYNC ####
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass pass
@@ -43,81 +44,87 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass pass
#### CALL HOOKS - proxy only #### #### CALL HOOKS - proxy only ####
""" """
Control and modify incoming / outgoing data before calling the model Control and modify incoming / outgoing data before calling the model
""" """
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings"],
):
pass pass
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func): def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["log_event_type"] = "pre_api_call" kwargs["log_event_type"] = "pre_api_call"
callback_func( callback_func(
kwargs, kwargs,
) )
print_verbose( print_verbose(f"Custom Logger - model call details: {kwargs}")
f"Custom Logger - model call details: {kwargs}" except:
)
except:
traceback.print_exc() traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func): async def async_log_input_event(
try: self, model, messages, kwargs, print_verbose, callback_func
):
try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["log_event_type"] = "pre_api_call" kwargs["log_event_type"] = "pre_api_call"
await callback_func( await callback_func(
kwargs, kwargs,
) )
print_verbose( print_verbose(f"Custom Logger - model call details: {kwargs}")
f"Custom Logger - model call details: {kwargs}" except:
)
except:
traceback.print_exc() traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): def log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition # Method definition
try: try:
kwargs["log_event_type"] = "post_api_call" kwargs["log_event_type"] = "post_api_call"
callback_func( callback_func(
kwargs, # kwargs to func kwargs, # kwargs to func
response_obj, response_obj,
start_time, start_time,
end_time, end_time,
) )
print_verbose( print_verbose(f"Custom Logger - final response object: {response_obj}")
f"Custom Logger - final response object: {response_obj}"
)
except: except:
# traceback.print_exc() # traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): async def async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition # Method definition
try: try:
kwargs["log_event_type"] = "post_api_call" kwargs["log_event_type"] = "post_api_call"
await callback_func( await callback_func(
kwargs, # kwargs to func kwargs, # kwargs to func
response_obj, response_obj,
start_time, start_time,
end_time, end_time,
) )
print_verbose( print_verbose(f"Custom Logger - final response object: {response_obj}")
f"Custom Logger - final response object: {response_obj}"
)
except: except:
# traceback.print_exc() # traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
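A minimal sketch of hooking into the class above (the module path litellm.integrations.custom_logger and the litellm.callbacks registration follow the documented pattern; the handler itself is made up, and the model/API key are placeholders):

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):  # hypothetical handler, for illustration
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        elapsed = (end_time - start_time).total_seconds()
        print(f"success: model={kwargs.get('model')} took {elapsed:.2f}s")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"failure: model={kwargs.get('model')}")

litellm.callbacks = [MyHandler()]
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)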
@@ -10,19 +10,28 @@ import datetime, subprocess, sys
import litellm, uuid import litellm, uuid
from litellm._logging import print_verbose from litellm._logging import print_verbose
class DyanmoDBLogger: class DyanmoDBLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
# Instance variables # Instance variables
import boto3 import boto3
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
self.dynamodb = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None: if litellm.dynamodb_table_name is None:
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`") raise ValueError(
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
)
self.table_name = litellm.dynamodb_table_name self.table_name = litellm.dynamodb_table_name
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try: try:
print_verbose( print_verbose(
@@ -32,7 +41,9 @@ class DyanmoDBLogger:
# construct payload to send to DynamoDB # construct payload to send to DynamoDB
# follows the same params as langfuse.py # follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages") messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion") call_type = kwargs.get("call_type", "litellm.completion")
@@ -51,7 +62,7 @@ class DyanmoDBLogger:
"messages": messages, "messages": messages,
"response": response_obj, "response": response_obj,
"usage": usage, "usage": usage,
"metadata": metadata "metadata": metadata,
} }
# Ensure everything in the payload is converted to str # Ensure everything in the payload is converted to str
@@ -62,9 +73,8 @@ class DyanmoDBLogger:
# non blocking if it can't cast to a str # non blocking if it can't cast to a str
pass pass
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}") print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
# put data in DynamoDB # put data in DynamoDB
table = self.dynamodb.Table(self.table_name) table = self.dynamodb.Table(self.table_name)
# Assuming log_data is a dictionary with log information # Assuming log_data is a dictionary with log information
@@ -79,4 +89,4 @@ class DyanmoDBLogger:
except: except:
traceback.print_exc() traceback.print_exc()
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}") print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass pass
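A hedged sketch of turning this logger on (the "dynamodb" callback name and the region/table values are assumptions based on the docs; AWS credentials must already be configured for boto3, and the model/API key are placeholders):

import os
import litellm

os.environ["AWS_REGION_NAME"] = "us-west-2"   # read by the logger above
litellm.dynamodb_table_name = "litellm-logs"  # required, see the ValueError above
litellm.success_callback = ["dynamodb"]       # assumed callback name, per the docs

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)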
@@ -64,7 +64,9 @@ class LangFuseLogger:
# end of processing langfuse ######################## # end of processing langfuse ########################
input = prompt input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}") print(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
)
self._log_langfuse_v2( self._log_langfuse_v2(
user_id, user_id,
metadata, metadata,
@ -171,7 +173,6 @@ class LangFuseLogger:
user_id=user_id, user_id=user_id,
) )
trace.generation( trace.generation(
name=metadata.get("generation_name", "litellm-completion"), name=metadata.get("generation_name", "litellm-completion"),
startTime=start_time, startTime=start_time,
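Since trace.generation(...) above pulls its name from metadata.get("generation_name", "litellm-completion"), callers can label generations per request. A rough sketch, assuming "langfuse" is the registered callback name and the Langfuse keys are already in the environment:

import litellm

# Assumption: "langfuse" success-callback name; LANGFUSE_PUBLIC_KEY /
# LANGFUSE_SECRET_KEY are assumed to be exported already.
litellm.success_callback = ["langfuse"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    # Surfaced to the logger via litellm_params["metadata"] and used as the
    # generation name instead of the "litellm-completion" default.
    metadata={"generation_name": "welcome-message"},
)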

View file

@ -8,25 +8,27 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class LangsmithLogger: class LangsmithLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY") self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition # Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb # inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
metadata = {} metadata = {}
if "litellm_params" in kwargs: if "litellm_params" in kwargs:
metadata = kwargs["litellm_params"].get("metadata", {}) metadata = kwargs["litellm_params"].get("metadata", {})
# set project name and run_name for langsmith logging # set project name and run_name for langsmith logging
# users can pass project_name and run name to litellm.completion() # users can pass project_name and run name to litellm.completion()
# Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"}) # Example: litellm.completion(model, messages, metadata={"project_name": "my-litellm-project", "run_name": "my-langsmith-run"})
# if not set litellm will use default project_name = litellm-completion, run_name = LLMRun # if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
project_name = metadata.get("project_name", "litellm-completion") project_name = metadata.get("project_name", "litellm-completion")
run_name = metadata.get("run_name", "LLMRun") run_name = metadata.get("run_name", "LLMRun")
print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}") print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try: try:
print_verbose( print_verbose(
f"Langsmith Logging - Enters logging function for model {kwargs}" f"Langsmith Logging - Enters logging function for model {kwargs}"
@ -34,6 +36,7 @@ class LangsmithLogger:
import requests import requests
import datetime import datetime
from datetime import timezone from datetime import timezone
try: try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat() start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat() end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
@ -45,7 +48,7 @@ class LangsmithLogger:
new_kwargs = {} new_kwargs = {}
for key in kwargs: for key in kwargs:
value = kwargs[key] value = kwargs[key]
if key == "start_time" or key =="end_time": if key == "start_time" or key == "end_time":
pass pass
elif type(value) != dict: elif type(value) != dict:
new_kwargs[key] = value new_kwargs[key] = value
@ -54,18 +57,14 @@ class LangsmithLogger:
"https://api.smith.langchain.com/runs", "https://api.smith.langchain.com/runs",
json={ json={
"name": run_name, "name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": { "inputs": {**new_kwargs},
**new_kwargs
},
"outputs": response_obj.json(), "outputs": response_obj.json(),
"session_name": project_name, "session_name": project_name,
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,
}, },
headers={ headers={"x-api-key": self.langsmith_api_key},
"x-api-key": self.langsmith_api_key
}
) )
print_verbose( print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}" f"Langsmith Layer Logging - final response object: {response_obj}"

View file

@ -1,6 +1,7 @@
import requests, traceback, json, os import requests, traceback, json, os
import types import types
class LiteDebugger: class LiteDebugger:
user_email = None user_email = None
dashboard_url = None dashboard_url = None
@ -12,9 +13,15 @@ class LiteDebugger:
def validate_environment(self, email): def validate_environment(self, email):
try: try:
self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")) self.user_email = (
if self.user_email == None: # if users are trying to use_client=True but token not set email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token") )
if (
self.user_email == None
): # if users are trying to use_client=True but token not set
raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
)
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try: try:
print( print(
@ -42,7 +49,9 @@ class LiteDebugger:
litellm_params, litellm_params,
optional_params, optional_params,
): ):
print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}") print_verbose(
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
)
try: try:
print_verbose( print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
@ -56,7 +65,11 @@ class LiteDebugger:
updated_litellm_params = remove_key_value(litellm_params, "logger_fn") updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding": if call_type == "embedding":
for message in messages: # assuming the input is a list as required by the embedding function for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = { litellm_data_obj = {
"model": model, "model": model,
"messages": [{"role": "user", "content": message}], "messages": [{"role": "user", "content": message}],
@ -79,7 +92,9 @@ class LiteDebugger:
elif call_type == "completion": elif call_type == "completion":
litellm_data_obj = { litellm_data_obj = {
"model": model, "model": model,
"messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}], "messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user, "end_user": end_user,
"status": "initiated", "status": "initiated",
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
@ -95,20 +110,30 @@ class LiteDebugger:
headers={"content-type": "application/json"}, headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj), data=json.dumps(litellm_data_obj),
) )
print_verbose(f"LiteDebugger: completion api response - {response.text}") print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except: except:
print_verbose( print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
) )
pass pass
def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream): def post_call_log_event(
print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}") self, original_response, litellm_call_id, print_verbose, call_type, stream
):
print_verbose(
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
)
try: try:
if call_type == "embedding": if call_type == "embedding":
litellm_data_obj = { litellm_data_obj = {
"status": "received", "status": "received",
"additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector "additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"user_email": self.user_email, "user_email": self.user_email,
} }
@ -122,7 +147,11 @@ class LiteDebugger:
elif call_type == "completion" and stream: elif call_type == "completion" and stream:
litellm_data_obj = { litellm_data_obj = {
"status": "received", "status": "received",
"additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response}, "additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"user_email": self.user_email, "user_email": self.user_email,
} }
@ -146,10 +175,12 @@ class LiteDebugger:
end_time, end_time,
litellm_call_id, litellm_call_id,
print_verbose, print_verbose,
call_type, call_type,
stream = False stream=False,
): ):
print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}") print_verbose(
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
)
try: try:
print_verbose( print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}" f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
@ -186,7 +217,7 @@ class LiteDebugger:
data=json.dumps(litellm_data_obj), data=json.dumps(litellm_data_obj),
) )
elif call_type == "completion" and stream == True: elif call_type == "completion" and stream == True:
if len(response_obj["content"]) > 0: # don't log the empty strings if len(response_obj["content"]) > 0: # don't log the empty strings
litellm_data_obj = { litellm_data_obj = {
"response_time": response_time, "response_time": response_time,
"total_cost": total_cost, "total_cost": total_cost,

View file

@ -18,19 +18,17 @@ class PromptLayerLogger:
# Method definition # Method definition
try: try:
new_kwargs = {} new_kwargs = {}
new_kwargs['model'] = kwargs['model'] new_kwargs["model"] = kwargs["model"]
new_kwargs['messages'] = kwargs['messages'] new_kwargs["messages"] = kwargs["messages"]
# add kwargs["optional_params"] to new_kwargs # add kwargs["optional_params"] to new_kwargs
for optional_param in kwargs["optional_params"]: for optional_param in kwargs["optional_params"]:
new_kwargs[optional_param] = kwargs["optional_params"][optional_param] new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
print_verbose( print_verbose(
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}" f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
) )
request_response = requests.post( request_response = requests.post(
"https://api.promptlayer.com/rest/track-request", "https://api.promptlayer.com/rest/track-request",
json={ json={
@ -51,8 +49,8 @@ class PromptLayerLogger:
f"Prompt Layer Logging: success - final response object: {request_response.text}" f"Prompt Layer Logging: success - final response object: {request_response.text}"
) )
response_json = request_response.json() response_json = request_response.json()
if "success" not in request_response.json(): if "success" not in request_response.json():
raise Exception("Promptlayer did not successfully log the response!") raise Exception("Promptlayer did not successfully log the response!")
if "request_id" in response_json: if "request_id" in response_json:
print(kwargs["litellm_params"]["metadata"]) print(kwargs["litellm_params"]["metadata"])
@ -62,10 +60,12 @@ class PromptLayerLogger:
json={ json={
"request_id": response_json["request_id"], "request_id": response_json["request_id"],
"api_key": self.key, "api_key": self.key,
"metadata": kwargs["litellm_params"]["metadata"] "metadata": kwargs["litellm_params"]["metadata"],
}, },
) )
print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}") print_verbose(
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
)
except: except:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}") print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")

View file

@ -9,6 +9,7 @@ import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm import litellm
class Supabase: class Supabase:
# Class variables or attributes # Class variables or attributes
supabase_table_name = "request_logs" supabase_table_name = "request_logs"
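Only the table name survives in this hunk; a usage sketch under the assumption that the logger is enabled via a "supabase" callback and reads SUPABASE_URL / SUPABASE_KEY from the environment:

import os
import litellm

# Assumed variable names; rows land in the "request_logs" table per the class above.
os.environ["SUPABASE_URL"] = "https://<project>.supabase.co"
os.environ["SUPABASE_KEY"] = ""

litellm.success_callback = ["supabase"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
)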

View file

@ -2,6 +2,7 @@ class TraceloopLogger:
def __init__(self): def __init__(self):
from traceloop.sdk.tracing.tracing import TracerWrapper from traceloop.sdk.tracing.tracing import TracerWrapper
from traceloop.sdk import Traceloop from traceloop.sdk import Traceloop
Traceloop.init(app_name="Litellm-Server", disable_batch=True) Traceloop.init(app_name="Litellm-Server", disable_batch=True)
self.tracer_wrapper = TracerWrapper() self.tracer_wrapper = TracerWrapper()
@ -29,15 +30,18 @@ class TraceloopLogger:
) )
if "stop" in optional_params: if "stop" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop") SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
optional_params.get("stop"),
) )
if "frequency_penalty" in optional_params: if "frequency_penalty" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty") SpanAttributes.LLM_FREQUENCY_PENALTY,
optional_params.get("frequency_penalty"),
) )
if "presence_penalty" in optional_params: if "presence_penalty" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_PRESENCE_PENALTY, optional_params.get("presence_penalty") SpanAttributes.LLM_PRESENCE_PENALTY,
optional_params.get("presence_penalty"),
) )
if "top_p" in optional_params: if "top_p" in optional_params:
span.set_attribute( span.set_attribute(
@ -45,7 +49,10 @@ class TraceloopLogger:
) )
if "tools" in optional_params or "functions" in optional_params: if "tools" in optional_params or "functions" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions")) SpanAttributes.LLM_REQUEST_FUNCTIONS,
optional_params.get(
"tools", optional_params.get("functions")
),
) )
if "user" in optional_params: if "user" in optional_params:
span.set_attribute( span.set_attribute(
@ -53,7 +60,8 @@ class TraceloopLogger:
) )
if "max_tokens" in optional_params: if "max_tokens" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens") SpanAttributes.LLM_REQUEST_MAX_TOKENS,
kwargs.get("max_tokens"),
) )
if "temperature" in optional_params: if "temperature" in optional_params:
span.set_attribute( span.set_attribute(
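The span attributes above are populated straight from optional_params, so the per-call kwargs decide what lands on the trace. A sketch, assuming "traceloop" is the registered callback name:

import litellm

# Traceloop.init(app_name="Litellm-Server", disable_batch=True) runs in the
# logger's __init__, so enabling the callback is enough here.
litellm.success_callback = ["traceloop"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    # Mapped onto LLM_REQUEST_MAX_TOKENS, LLM_CHAT_STOP_SEQUENCES, etc. above.
    max_tokens=64,
    stop=["\n\n"],
    user="user-123",
)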

View file

@ -1,4 +1,4 @@
imported_openAIResponse=True imported_openAIResponse = True
try: try:
import io import io
import logging import logging
@ -12,15 +12,12 @@ try:
else: else:
from typing_extensions import Literal, Protocol from typing_extensions import Literal, Protocol
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
K = TypeVar("K", bound=str) K = TypeVar("K", bound=str)
V = TypeVar("V") V = TypeVar("V")
class OpenAIResponse(Protocol[K, V]): # type: ignore
class OpenAIResponse(Protocol[K, V]): # type: ignore
# contains a (known) object attribute # contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"] object: Literal["chat.completion", "edit", "text_completion"]
@ -30,7 +27,6 @@ try:
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover ... # pragma: no cover
class OpenAIRequestResponseResolver: class OpenAIRequestResponseResolver:
def __call__( def __call__(
self, self,
@ -44,7 +40,9 @@ try:
elif response["object"] == "text_completion": elif response["object"] == "text_completion":
return self._resolve_completion(request, response, time_elapsed) return self._resolve_completion(request, response, time_elapsed)
elif response["object"] == "chat.completion": elif response["object"] == "chat.completion":
return self._resolve_chat_completion(request, response, time_elapsed) return self._resolve_chat_completion(
request, response, time_elapsed
)
else: else:
logger.info(f"Unknown OpenAI response object: {response['object']}") logger.info(f"Unknown OpenAI response object: {response['object']}")
except Exception as e: except Exception as e:
@ -113,7 +111,8 @@ try:
"""Resolves the request and response objects for `openai.Completion`.""" """Resolves the request and response objects for `openai.Completion`."""
request_str = f"\n\n**Prompt**: {request['prompt']}\n" request_str = f"\n\n**Prompt**: {request['prompt']}\n"
choices = [ choices = [
f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"] f"\n\n**Completion**: {choice['text']}\n"
for choice in response["choices"]
] ]
return self._request_response_result_to_trace( return self._request_response_result_to_trace(
@ -167,9 +166,9 @@ try:
] ]
trace = self.results_to_trace_tree(request, response, results, time_elapsed) trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace return trace
except:
imported_openAIResponse=False
except:
imported_openAIResponse = False
#### What this does #### #### What this does ####
@ -182,29 +181,34 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class WeightsBiasesLogger: class WeightsBiasesLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
try: try:
import wandb import wandb
except: except:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") raise Exception(
if imported_openAIResponse==False: "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") )
if imported_openAIResponse == False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
self.resolver = OpenAIRequestResponseResolver() self.resolver = OpenAIRequestResponseResolver()
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition # Method definition
import wandb import wandb
try: try:
print_verbose( print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
f"W&B Logging - Enters logging function for model {kwargs}"
)
run = wandb.init() run = wandb.init()
print_verbose(response_obj) print_verbose(response_obj)
trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds()) trace = self.resolver(
kwargs, response_obj, (end_time - start_time).total_seconds()
)
if trace is not None: if trace is not None:
run.log({"trace": trace}) run.log({"trace": trace})

View file

@ -7,76 +7,93 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message from litellm.utils import ModelResponse, Choices, Message
import litellm import litellm
class AI21Error(Exception): class AI21Error(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/") self.request = httpx.Request(
method="POST", url="https://api.ai21.com/studio/v1/"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AI21Config():
class AI21Config:
""" """
Reference: https://docs.ai21.com/reference/j2-complete-ref Reference: https://docs.ai21.com/reference/j2-complete-ref
The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters: The class `AI21Config` provides configuration for the AI21's API interface. Below are the parameters:
- `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful. - `numResults` (int32): Number of completions to sample and return. Optional, default is 1. If the temperature is greater than 0 (non-greedy decoding), a value greater than 1 can be meaningful.
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated. - `minTokens` (int32): The minimum number of tokens to generate per result. Optional, default is 0. If `stopSequences` are given, they are ignored until `minTokens` are generated.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position. - `topKReturn` (int32): Range between 0 to 10, including both. Optional, default is 0. Specifies the top-K alternative tokens to return. A non-zero value includes the string representations and log-probabilities for each of the top-K alternatives at each position.
- `frequencyPenalty` (object): Placeholder for frequency penalty object. - `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object. - `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object. - `countPenalty` (object): Placeholder for count penalty object.
""" """
numResults: Optional[int]=None
maxTokens: Optional[int]=None
minTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
topKReturn: Optional[int]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self, numResults: Optional[int] = None
numResults: Optional[int]=None, maxTokens: Optional[int] = None
maxTokens: Optional[int]=None, minTokens: Optional[int] = None
minTokens: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[float] = None
topP: Optional[float]=None, stopSequences: Optional[list] = None
stopSequences: Optional[list]=None, topKReturn: Optional[int] = None
topKReturn: Optional[int]=None, frequencePenalty: Optional[dict] = None
frequencePenalty: Optional[dict]=None, presencePenalty: Optional[dict] = None
presencePenalty: Optional[dict]=None, countPenalty: Optional[dict] = None
countPenalty: Optional[dict]=None) -> None:
def __init__(
self,
numResults: Optional[int] = None,
maxTokens: Optional[int] = None,
minTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
topKReturn: Optional[int] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@ -91,6 +108,7 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -110,20 +128,18 @@ def completion(
for message in messages: for message in messages:
if "role" in message: if "role" in message:
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
## Load Config ## Load Config
config = litellm.AI21Config.get_config() config = litellm.AI21Config.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -134,29 +150,26 @@ def completion(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
api_base + model + "/complete", headers=headers, data=json.dumps(data) api_base + model + "/complete", headers=headers, data=json.dumps(data)
) )
if response.status_code != 200: if response.status_code != 200:
raise AI21Error( raise AI21Error(status_code=response.status_code, message=response.text)
status_code=response.status_code,
message=response.text
)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response.json() completion_response = response.json()
try: try:
@ -164,18 +177,22 @@ def completion(
for idx, item in enumerate(completion_response["completions"]): for idx, item in enumerate(completion_response["completions"]):
if len(item["data"]["text"]) > 0: if len(item["data"]["text"]) > 0:
message_obj = Message(content=item["data"]["text"]) message_obj = Message(content=item["data"]["text"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finishReason"]["reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code) raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
@ -189,6 +206,7 @@ def completion(
} }
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
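The "## Load Config" loop in completion() only copies config keys that are absent from optional_params, i.e. per-call arguments win over AI21Config defaults. A sketch of that precedence (the "j2-mid" model name and key handling are illustrative assumptions):

import litellm

# AI21Config.__init__ writes non-None values onto the class, and get_config()
# is what completion() merges into optional_params.
litellm.AI21Config(maxTokens=200, temperature=0.3)

litellm.completion(
    model="j2-mid",  # illustrative AI21 model name; the AI21 key is assumed to be set
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.9,  # overrides the 0.3 config default for this call only
)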

View file

@ -8,17 +8,21 @@ import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx import httpx
class AlephAlphaError(Exception): class AlephAlphaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete") self.request = httpx.Request(
method="POST", url="https://api.aleph-alpha.com/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AlephAlphaConfig():
class AlephAlphaConfig:
""" """
Reference: https://docs.aleph-alpha.com/api/complete/ Reference: https://docs.aleph-alpha.com/api/complete/
@ -42,13 +46,13 @@ class AlephAlphaConfig():
- `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied. - `repetition_penalties_include_prompt`, `repetition_penalties_include_completion`, `use_multiplicative_presence_penalty`,`use_multiplicative_frequency_penalty`,`use_multiplicative_sequence_penalty` (boolean, nullable; default value: false): Various settings that adjust how the repetition penalties are applied.
- `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties. - `penalty_bias` (string, nullable): Text used in addition to the penalized tokens for repetition penalties.
- `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty. - `penalty_exceptions` (string[], nullable): Strings that may be generated without penalty.
- `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions. - `penalty_exceptions_include_stop_sequences` (boolean, nullable; default value: true): Include all stop_sequences in penalty_exceptions.
- `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side. - `best_of` (integer, nullable; default value: 1): The number of completions will be generated on the server side.
- `n` (integer, nullable; default value: 1): The number of completions to return. - `n` (integer, nullable; default value: 1): The number of completions to return.
@ -68,87 +72,101 @@ class AlephAlphaConfig():
- `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion. - `completion_bias_inclusion_first_token_only`, `completion_bias_exclusion_first_token_only` (boolean; default value: false): Consider only the first token for the completion_bias_inclusion/exclusion.
- `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled. - `contextual_control_threshold` (number, nullable): Control over how similar tokens are controlled.
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
""" """
maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int]=None
echo: Optional[bool]=None
temperature: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
presence_penalty: Optional[int]=None
frequency_penalty: Optional[int]=None
sequence_penalty: Optional[int]=None
sequence_penalty_min_length: Optional[int]=None
repetition_penalties_include_prompt: Optional[bool]=None
repetition_penalties_include_completion: Optional[bool]=None
use_multiplicative_presence_penalty: Optional[bool]=None
use_multiplicative_frequency_penalty: Optional[bool]=None
use_multiplicative_sequence_penalty: Optional[bool]=None
penalty_bias: Optional[str]=None
penalty_exceptions_include_stop_sequences: Optional[bool]=None
best_of: Optional[int]=None
n: Optional[int]=None
logit_bias: Optional[dict]=None
log_probs: Optional[int]=None
stop_sequences: Optional[list]=None
tokens: Optional[bool]=None
raw_completion: Optional[bool]=None
disable_optimizations: Optional[bool]=None
completion_bias_inclusion: Optional[list]=None
completion_bias_exclusion: Optional[list]=None
completion_bias_inclusion_first_token_only: Optional[bool]=None
completion_bias_exclusion_first_token_only: Optional[bool]=None
contextual_control_threshold: Optional[int]=None
control_log_additive: Optional[bool]=None
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
presence_penalty: Optional[int] = None
frequency_penalty: Optional[int] = None
sequence_penalty: Optional[int] = None
sequence_penalty_min_length: Optional[int] = None
repetition_penalties_include_prompt: Optional[bool] = None
repetition_penalties_include_completion: Optional[bool] = None
use_multiplicative_presence_penalty: Optional[bool] = None
use_multiplicative_frequency_penalty: Optional[bool] = None
use_multiplicative_sequence_penalty: Optional[bool] = None
penalty_bias: Optional[str] = None
penalty_exceptions_include_stop_sequences: Optional[bool] = None
best_of: Optional[int] = None
n: Optional[int] = None
logit_bias: Optional[dict] = None
log_probs: Optional[int] = None
stop_sequences: Optional[list] = None
tokens: Optional[bool] = None
raw_completion: Optional[bool] = None
disable_optimizations: Optional[bool] = None
completion_bias_inclusion: Optional[list] = None
completion_bias_exclusion: Optional[list] = None
completion_bias_inclusion_first_token_only: Optional[bool] = None
completion_bias_exclusion_first_token_only: Optional[bool] = None
contextual_control_threshold: Optional[int] = None
control_log_additive: Optional[bool] = None
def __init__(self, def __init__(
maximum_tokens: Optional[int]=None, self,
minimum_tokens: Optional[int]=None, maximum_tokens: Optional[int] = None,
echo: Optional[bool]=None, minimum_tokens: Optional[int] = None,
temperature: Optional[int]=None, echo: Optional[bool] = None,
top_k: Optional[int]=None, temperature: Optional[int] = None,
top_p: Optional[int]=None, top_k: Optional[int] = None,
presence_penalty: Optional[int]=None, top_p: Optional[int] = None,
frequency_penalty: Optional[int]=None, presence_penalty: Optional[int] = None,
sequence_penalty: Optional[int]=None, frequency_penalty: Optional[int] = None,
sequence_penalty_min_length: Optional[int]=None, sequence_penalty: Optional[int] = None,
repetition_penalties_include_prompt: Optional[bool]=None, sequence_penalty_min_length: Optional[int] = None,
repetition_penalties_include_completion: Optional[bool]=None, repetition_penalties_include_prompt: Optional[bool] = None,
use_multiplicative_presence_penalty: Optional[bool]=None, repetition_penalties_include_completion: Optional[bool] = None,
use_multiplicative_frequency_penalty: Optional[bool]=None, use_multiplicative_presence_penalty: Optional[bool] = None,
use_multiplicative_sequence_penalty: Optional[bool]=None, use_multiplicative_frequency_penalty: Optional[bool] = None,
penalty_bias: Optional[str]=None, use_multiplicative_sequence_penalty: Optional[bool] = None,
penalty_exceptions_include_stop_sequences: Optional[bool]=None, penalty_bias: Optional[str] = None,
best_of: Optional[int]=None, penalty_exceptions_include_stop_sequences: Optional[bool] = None,
n: Optional[int]=None, best_of: Optional[int] = None,
logit_bias: Optional[dict]=None, n: Optional[int] = None,
log_probs: Optional[int]=None, logit_bias: Optional[dict] = None,
stop_sequences: Optional[list]=None, log_probs: Optional[int] = None,
tokens: Optional[bool]=None, stop_sequences: Optional[list] = None,
raw_completion: Optional[bool]=None, tokens: Optional[bool] = None,
disable_optimizations: Optional[bool]=None, raw_completion: Optional[bool] = None,
completion_bias_inclusion: Optional[list]=None, disable_optimizations: Optional[bool] = None,
completion_bias_exclusion: Optional[list]=None, completion_bias_inclusion: Optional[list] = None,
completion_bias_inclusion_first_token_only: Optional[bool]=None, completion_bias_exclusion: Optional[list] = None,
completion_bias_exclusion_first_token_only: Optional[bool]=None, completion_bias_inclusion_first_token_only: Optional[bool] = None,
contextual_control_threshold: Optional[int]=None, completion_bias_exclusion_first_token_only: Optional[bool] = None,
control_log_additive: Optional[bool]=None) -> None: contextual_control_threshold: Optional[int] = None,
control_log_additive: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@ -160,6 +178,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -177,9 +196,11 @@ def completion(
headers = validate_environment(api_key) headers = validate_environment(api_key)
## Load Config ## Load Config
config = litellm.AlephAlphaConfig.get_config() config = litellm.AlephAlphaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
completion_url = api_base completion_url = api_base
@ -188,21 +209,17 @@ def completion(
if "control" in model: # follow the ###Instruction / ###Response format if "control" in model: # follow the ###Instruction / ###Response format
for idx, message in enumerate(messages): for idx, message in enumerate(messages):
if "role" in message: if "role" in message:
if idx == 0: # set first message as instruction (required), let later user messages be input if (
idx == 0
): # set first message as instruction (required), let later user messages be input
prompt += f"###Instruction: {message['content']}" prompt += f"###Instruction: {message['content']}"
else: else:
if message["role"] == "system": if message["role"] == "system":
prompt += ( prompt += f"###Instruction: {message['content']}"
f"###Instruction: {message['content']}"
)
elif message["role"] == "user": elif message["role"] == "user":
prompt += ( prompt += f"###Input: {message['content']}"
f"###Input: {message['content']}"
)
else: else:
prompt += ( prompt += f"###Response: {message['content']}"
f"###Response: {message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
else: else:
@ -215,24 +232,27 @@ def completion(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response.json() completion_response = response.json()
@ -247,18 +267,23 @@ def completion(
for idx, item in enumerate(completion_response["completions"]): for idx, item in enumerate(completion_response["completions"]):
if len(item["completion"]) > 0: if len(item["completion"]) > 0:
message_obj = Message(content=item["completion"]) message_obj = Message(content=item["completion"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except: except:
raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code) raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"]) encoding.encode(model_response["choices"][0]["message"]["content"])
) )
@ -268,11 +293,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
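A standalone sketch of the prompt layout completion() builds above for models with "control" in their name (###Instruction / ###Input / ###Response); this simply mirrors the diffed branching outside the library:

messages = [
    {"role": "system", "content": "Answer tersely."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

prompt = ""
for idx, message in enumerate(messages):
    if idx == 0:
        # the first message is always treated as the instruction (required)
        prompt += f"###Instruction: {message['content']}"
    elif message["role"] == "system":
        prompt += f"###Instruction: {message['content']}"
    elif message["role"] == "user":
        prompt += f"###Input: {message['content']}"
    else:
        prompt += f"###Response: {message['content']}"

print(prompt)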

View file

@ -5,56 +5,76 @@ import requests
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
class AnthropicError(Exception): class AnthropicError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete") self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AnthropicConfig():
class AnthropicConfig:
""" """
Reference: https://docs.anthropic.com/claude/reference/complete_post Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
""" """
max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
metadata: Optional[dict]=None
def __init__(self, max_tokens_to_sample: Optional[
max_tokens_to_sample: Optional[int]=256, # anthropic requires a default int
stop_sequences: Optional[list]=None, ] = litellm.max_tokens # anthropic requires a default
temperature: Optional[int]=None, stop_sequences: Optional[list] = None
top_p: Optional[int]=None, temperature: Optional[int] = None
top_k: Optional[int]=None, top_p: Optional[int] = None
metadata: Optional[dict]=None) -> None: top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# makes headers for API call # makes headers for API call
@ -71,6 +91,7 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -87,21 +108,25 @@ def completion(
): ):
headers = validate_environment(api_key) headers = validate_environment(api_key)
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic") prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config ## Load Config
config = litellm.AnthropicConfig.get_config() config = litellm.AnthropicConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -116,7 +141,7 @@ def completion(
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
) )
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post( response = requests.post(
@ -125,18 +150,20 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"], stream=optional_params["stream"],
) )
if response.status_code != 200: if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text) raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines() return response.iter_lines()
else: else:
response = requests.post( response = requests.post(api_base, headers=headers, data=json.dumps(data))
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200: if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text) raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -159,9 +186,9 @@ def completion(
) )
else: else:
if len(completion_response["completion"]) > 0: if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response[ model_response["choices"][0]["message"][
"completion" "content"
] ] = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response["stop_reason"] model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE ## CALCULATING USAGE
@ -177,11 +204,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
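Per the AnthropicConfig docstring above, Anthropic-side metadata takes the shape {"user_id": "any-relevant-information"}, and config defaults are merged into the request only when the same key is not passed per call. A sketch of setting provider-level defaults (values are illustrative):

import litellm

# Written onto the class by AnthropicConfig.__init__ and merged in completion()
# for keys not already present in optional_params.
litellm.AnthropicConfig(max_tokens_to_sample=512, temperature=0)

litellm.completion(
    model="claude-2",
    messages=[{"role": "user", "content": "Hi"}],
)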

View file

@ -1,7 +1,13 @@
from typing import Optional, Union, Any from typing import Optional, Union, Any
import types, requests import types, requests
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
)
from typing import Callable, Optional from typing import Callable, Optional
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
@ -9,8 +15,15 @@ import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
if request: if request:
@ -20,11 +33,14 @@ class AzureOpenAIError(Exception):
if response: if response:
self.response = response self.response = response
else: else:
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AzureOpenAIConfig(OpenAIConfig): class AzureOpenAIConfig(OpenAIConfig):
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -49,33 +65,37 @@ class AzureOpenAIConfig(OpenAIConfig):
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
""" """
def __init__(self, def __init__(
frequency_penalty: Optional[int] = None, self,
function_call: Optional[Union[str, dict]]= None, frequency_penalty: Optional[int] = None,
functions: Optional[list]= None, function_call: Optional[Union[str, dict]] = None,
logit_bias: Optional[dict]= None, functions: Optional[list] = None,
max_tokens: Optional[int]= None, logit_bias: Optional[dict] = None,
n: Optional[int]= None, max_tokens: Optional[int] = None,
presence_penalty: Optional[int]= None, n: Optional[int] = None,
stop: Optional[Union[str,list]]=None, presence_penalty: Optional[int] = None,
temperature: Optional[int]= None, stop: Optional[Union[str, list]] = None,
top_p: Optional[int]= None) -> None: temperature: Optional[int] = None,
super().__init__(frequency_penalty, top_p: Optional[int] = None,
function_call, ) -> None:
functions, super().__init__(
logit_bias, frequency_penalty,
max_tokens, function_call,
n, functions,
presence_penalty, logit_bias,
stop, max_tokens,
temperature, n,
top_p) presence_penalty,
stop,
temperature,
top_p,
)
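# A short sketch of how a config class like AzureOpenAIConfig is meant to be
# consumed: class-level defaults are merged into the per-request params only
# when the caller has not supplied them (the same merge pattern appears for
# the Amazon*Config classes later in this commit). The get_config() helper is
# assumed to be inherited from OpenAIConfig; a plain dict stands in for it here.
optional_params = {"temperature": 0.2}  # hypothetical caller-supplied params
config = {"max_tokens": 256, "temperature": 0.7}  # stand-in for AzureOpenAIConfig.get_config()
for k, v in config.items():
    if k not in optional_params:  # caller-supplied values win over class defaults
        optional_params[k] = v
# optional_params == {"temperature": 0.2, "max_tokens": 256}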
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@@ -89,49 +109,51 @@ class AzureChatCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {azure_ad_token}" headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers return headers
def completion(self, def completion(
model: str, self,
messages: list, model: str,
model_response: ModelResponse, messages: list,
api_key: str, model_response: ModelResponse,
api_base: str, api_key: str,
api_version: str, api_base: str,
api_type: str, api_version: str,
azure_ad_token: str, api_type: str,
print_verbose: Callable, azure_ad_token: str,
timeout, print_verbose: Callable,
logging_obj, timeout,
optional_params, logging_obj,
litellm_params, optional_params,
logger_fn, litellm_params,
acompletion: bool = False, logger_fn,
headers: Optional[dict]=None, acompletion: bool = False,
client = None, headers: Optional[dict] = None,
): client=None,
):
super().completion() super().completion()
exception_mapping_worked = False exception_mapping_worked = False
try: try:
if model is None or messages is None: if model is None or messages is None:
raise AzureOpenAIError(status_code=422, message=f"Missing model or messages") raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", 2)
### CHECK IF CLOUDFLARE AI GATEWAY ### ### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url ### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name ## build base url - assume api base includes resource name
if client is None: if client is None:
if not api_base.endswith("/"): if not api_base.endswith("/"):
api_base += "/" api_base += "/"
api_base += f"{model}" api_base += f"{model}"
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
"base_url": f"{api_base}", "base_url": f"{api_base}",
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
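# A sketch of the base_url the Cloudflare AI Gateway branch above builds; the
# account, gateway, resource and deployment names are purely hypothetical. The
# deployment name is appended to api_base before the AzureOpenAI client is created.
api_base = "https://gateway.ai.cloudflare.com/v1/my-account/my-gateway/azure-openai/my-resource"
model = "my-gpt-35-turbo-deployment"
if not api_base.endswith("/"):
    api_base += "/"
base_url = api_base + model
# base_url -> ".../azure-openai/my-resource/my-gpt-35-turbo-deployment"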
@@ -142,26 +164,53 @@ class AzureChatCompletion(BaseLLM):
client = AsyncAzureOpenAI(**azure_client_params) client = AsyncAzureOpenAI(**azure_client_params)
else: else:
client = AzureOpenAI(**azure_client_params) client = AzureOpenAI(**azure_client_params)
data = {"model": None, "messages": messages, **optional_params}
else:
data = { data = {
"model": None, "model": model, # type: ignore
"messages": messages, "messages": messages,
**optional_params **optional_params,
} }
else:
data = { if acompletion is True:
"model": model, # type: ignore
"messages": messages,
**optional_params
}
if acompletion is True:
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else: else:
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj) return self.acompletion(
api_base=api_base,
data=data,
model_response=model_response,
api_key=api_key,
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True: elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else: else:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@@ -169,16 +218,18 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key, api_key=api_key,
additional_args={ additional_args={
"headers": { "headers": {
"api_key": api_key, "api_key": api_key,
"azure_ad_token": azure_ad_token "azure_ad_token": azure_ad_token,
}, },
"api_version": api_version, "api_version": api_version,
"api_base": api_base, "api_base": api_base,
"complete_input_dict": data, "complete_input_dict": data,
}, },
) )
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@@ -196,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
response = azure_client.chat.completions.create(**data) # type: ignore response = azure_client.chat.completions.create(**data) # type: ignore
stringified_response = response.model_dump_json() stringified_response = response.model_dump_json()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@@ -209,30 +260,36 @@ class AzureChatCompletion(BaseLLM):
"api_base": api_base, "api_base": api_base,
}, },
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) return convert_to_model_response_object(
except AzureOpenAIError as e: response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
raise e raise e
async def acompletion(self, async def acompletion(
api_key: str, self,
api_version: str, api_key: str,
model: str, api_version: str,
api_base: str, model: str,
data: dict, api_base: str,
timeout: Any, data: dict,
model_response: ModelResponse, timeout: Any,
azure_ad_token: Optional[str]=None, model_response: ModelResponse,
client = None, # this is the AsyncAzureOpenAI azure_ad_token: Optional[str] = None,
logging_obj=None, client=None, # this is the AsyncAzureOpenAI
): logging_obj=None,
response = None ):
try: response = None
try:
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@@ -252,35 +309,46 @@ class AzureChatCompletion(BaseLLM):
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await azure_client.chat.completions.create(**data) response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) return convert_to_model_response_object(
except AzureOpenAIError as e: response_object=json.loads(response.model_dump_json()),
model_response_object=model_response,
)
except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if hasattr(e, "status_code"): if hasattr(e, "status_code"):
raise e raise e
else: else:
raise AzureOpenAIError(status_code=500, message=str(e)) raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self, def streaming(
logging_obj, self,
api_base: str, logging_obj,
api_key: str, api_base: str,
api_version: str, api_key: str,
data: dict, api_version: str,
model: str, data: dict,
timeout: Any, model: str,
azure_ad_token: Optional[str]=None, timeout: Any,
client=None, azure_ad_token: Optional[str] = None,
): client=None,
):
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@@ -300,25 +368,36 @@ class AzureChatCompletion(BaseLLM):
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = azure_client.chat.completions.create(**data) response = azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper return streamwrapper
async def async_streaming(self, async def async_streaming(
logging_obj, self,
api_base: str, logging_obj,
api_key: str, api_base: str,
api_version: str, api_key: str,
data: dict, api_version: str,
model: str, data: dict,
timeout: Any, model: str,
azure_ad_token: Optional[str]=None, timeout: Any,
client = None, azure_ad_token: Optional[str] = None,
): client=None,
):
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@@ -326,39 +405,49 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2), "max_retries": data.pop("max_retries", 2),
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await azure_client.chat.completions.create(**data) response = await azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
async def aembedding( async def aembedding(
self, self,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
azure_client_params: dict, azure_client_params: dict,
api_key: str, api_key: str,
input: list, input: list,
client=None, client=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params) openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else: else:
@@ -367,50 +456,53 @@ class AzureChatCompletion(BaseLLM):
stringified_response = response.model_dump_json() stringified_response = response.model_dump_json()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=stringified_response, original_response=stringified_response,
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="embedding",
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=str(e), original_response=str(e),
) )
raise e raise e
def embedding(self, def embedding(
model: str, self,
input: list, model: str,
api_key: str, input: list,
api_base: str, api_key: str,
api_version: str, api_base: str,
timeout: float, api_version: str,
logging_obj=None, timeout: float,
model_response=None, logging_obj=None,
optional_params=None, model_response=None,
azure_ad_token: Optional[str]=None, optional_params=None,
client = None, azure_ad_token: Optional[str] = None,
aembedding=None, client=None,
): aembedding=None,
):
super().embedding() super().embedding()
exception_mapping_worked = False exception_mapping_worked = False
if self._client_session is None: if self._client_session is None:
self._client_session = self.create_client_session() self._client_session = self.create_client_session()
try: try:
data = { data = {"model": model, "input": input, **optional_params}
"model": model,
"input": input,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@@ -427,119 +519,130 @@ class AzureChatCompletion(BaseLLM):
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={ additional_args={
"complete_input_dict": data, "complete_input_dict": data,
"headers": { "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
"api_key": api_key, },
"azure_ad_token": azure_ad_token )
}
},
)
if aembedding == True: if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params) response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
api_key=api_key,
model_response=model_response,
azure_client_params=azure_client_params,
)
return response return response
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else: else:
azure_client = client azure_client = client
## COMPLETION CALL ## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore response = azure_client.embeddings.create(**data) # type: ignore
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
original_response=response, original_response=response,
) )
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore except AzureOpenAIError as e:
except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if exception_mapping_worked: if exception_mapping_worked:
raise e raise e
else: else:
import traceback import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation( async def aimage_generation(
self, self,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
azure_client_params: dict, azure_client_params: dict,
api_key: str, api_key: str,
input: list, input: list,
client=None, client=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) client_session = litellm.aclient_session or httpx.AsyncClient(
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params) transport=AsyncCustomHTTPTransport(),
)
openai_aclient = AsyncAzureOpenAI(
http_client=client_session, **azure_client_params
)
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.images.generate(**data) response = await openai_aclient.images.generate(**data)
stringified_response = response.model_dump_json() stringified_response = response.model_dump_json()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=stringified_response, original_response=stringified_response,
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="image_generation",
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=str(e), original_response=str(e),
) )
raise e raise e
def image_generation(self, def image_generation(
prompt: str, self,
timeout: float, prompt: str,
model: Optional[str]=None, timeout: float,
api_key: Optional[str] = None, model: Optional[str] = None,
api_base: Optional[str] = None, api_key: Optional[str] = None,
api_version: Optional[str] = None, api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None, api_version: Optional[str] = None,
azure_ad_token: Optional[str]=None, model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None, azure_ad_token: Optional[str] = None,
optional_params=None, logging_obj=None,
client=None, optional_params=None,
aimg_generation=None, client=None,
): aimg_generation=None,
):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
if model and len(model) > 0: if model and len(model) > 0:
model = model model = model
else: else:
model = None model = None
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@@ -547,39 +650,47 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True: if aimg_generation == True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params) # type: ignore
return response return response
if client is None: if client is None:
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) client_session = litellm.client_session or httpx.Client(
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore transport=CustomHTTPTransport(),
)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else: else:
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
) )
## COMPLETION CALL ## COMPLETION CALL
response = azure_client.images.generate(**data) # type: ignore response = azure_client.images.generate(**data) # type: ignore
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response, original_response=response,
) )
# return response # return response
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if exception_mapping_worked: if exception_mapping_worked:
raise e raise e
else: else:
import traceback import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

View file

@@ -1,47 +1,45 @@
## This is a template base class to be used for adding new LLM providers via API calls ## This is a template base class to be used for adding new LLM providers via API calls
import litellm import litellm
import httpx, certifi, ssl import httpx, certifi, ssl
from typing import Optional from typing import Optional
class BaseLLM: class BaseLLM:
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session
else: else:
_client_session = httpx.Client() _client_session = httpx.Client()
return _client_session return _client_session
def create_aclient_session(self): def create_aclient_session(self):
if litellm.aclient_session: if litellm.aclient_session:
_aclient_session = litellm.aclient_session _aclient_session = litellm.aclient_session
else: else:
_aclient_session = httpx.AsyncClient() _aclient_session = httpx.AsyncClient()
return _aclient_session return _aclient_session
def __exit__(self): def __exit__(self):
if hasattr(self, '_client_session'): if hasattr(self, "_client_session"):
self._client_session.close() self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'): if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose() await self._aclient_session.aclose()
def validate_environment(self): # set up the environment required to run the model def validate_environment(self): # set up the environment required to run the model
pass pass
def completion( def completion(
self, self, *args, **kwargs
*args,
**kwargs
): # logic for parsing in - calling - parsing out model completion calls ): # logic for parsing in - calling - parsing out model completion calls
pass pass
def embedding( def embedding(
self, self, *args, **kwargs
*args,
**kwargs
): # logic for parsing in - calling - parsing out model embedding calls ): # logic for parsing in - calling - parsing out model embedding calls
pass pass
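# A minimal sketch of extending the template base class above for a new
# provider; the module path is inferred from the relative import used earlier
# in this commit, and the provider URL and response handling are hypothetical.
from litellm.llms.base import BaseLLM


class MyProviderLLM(BaseLLM):
    def completion(self, *args, **kwargs):
        # parse inputs -> call the provider's API -> parse the response back out
        client = self.create_client_session()
        resp = client.post("https://api.example-provider.com/v1/complete", json=kwargs)
        return resp.json()

    def embedding(self, *args, **kwargs):
        pass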

View file

@@ -6,6 +6,7 @@ import time
from typing import Callable from typing import Callable
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
class BasetenError(Exception): class BasetenError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@@ -14,6 +15,7 @@ class BasetenError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"accept": "application/json", "accept": "application/json",
@@ -23,6 +25,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Api-Key {api_key}" headers["Authorization"] = f"Api-Key {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@@ -52,32 +55,38 @@ def completion(
"inputs": prompt, "inputs": prompt,
"prompt": prompt, "prompt": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url_fragment_1 + model + completion_url_fragment_2, completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=True if "stream" in optional_params and optional_params["stream"] == True else False stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
) )
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True): if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
):
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response.json() completion_response = response.json()
@@ -91,9 +100,7 @@ def completion(
if ( if (
isinstance(completion_response["model_output"], dict) isinstance(completion_response["model_output"], dict)
and "data" in completion_response["model_output"] and "data" in completion_response["model_output"]
and isinstance( and isinstance(completion_response["model_output"]["data"], list)
completion_response["model_output"]["data"], list
)
): ):
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
"content" "content"
@@ -112,12 +119,19 @@ def completion(
if "generated_text" not in completion_response: if "generated_text" not in completion_response:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code status_code=response.status_code,
) )
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] model_response["choices"][0]["message"][
## GETTING LOGPROBS "content"
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: ] = completion_response[0]["generated_text"]
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] ## GETTING LOGPROBS
if (
"details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
@@ -125,7 +139,7 @@ def completion(
else: else:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code status_code=response.status_code,
) )
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@@ -139,11 +153,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass

View file

@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
class BedrockError(Exception): class BedrockError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock") self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AmazonTitanConfig():
class AmazonTitanConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@@ -29,29 +33,44 @@ class AmazonTitanConfig():
- `temperature` (float) temperature for model, - `temperature` (float) temperature for model,
- `topP` (int) top p for model - `topP` (int) top p for model
""" """
maxTokenCount: Optional[int]=None
stopSequences: Optional[list]=None
temperature: Optional[float]=None
topP: Optional[int]=None
def __init__(self, maxTokenCount: Optional[int] = None
maxTokenCount: Optional[int]=None, stopSequences: Optional[list] = None
stopSequences: Optional[list]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[int] = None
topP: Optional[int]=None) -> None:
def __init__(
self,
maxTokenCount: Optional[int] = None,
stopSequences: Optional[list] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
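# A sketch of the request body the Amazon Titan branch of completion() later
# builds from this config (see the `provider == "amazon"` branch further down
# in this commit); the prompt and parameter values are hypothetical.
import json

inference_params = {"maxTokenCount": 256, "temperature": 0.7, "topP": 1}
data = json.dumps(
    {
        "inputText": "\n\nUser: Hello, Titan.\n\nBot: ",
        "textGenerationConfig": inference_params,
    }
)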
class AmazonAnthropicConfig():
class AmazonAnthropicConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"], - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
""" """
max_tokens_to_sample: Optional[int]=litellm.max_tokens
stop_sequences: Optional[list]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
anthropic_version: Optional[str]=None
def __init__(self, max_tokens_to_sample: Optional[int] = litellm.max_tokens
max_tokens_to_sample: Optional[int]=None, stop_sequences: Optional[list] = None
stop_sequences: Optional[list]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, top_k: Optional[int] = None
top_k: Optional[int]=None, top_p: Optional[int] = None
top_p: Optional[int]=None, anthropic_version: Optional[str] = None
anthropic_version: Optional[str]=None) -> None:
def __init__(
self,
max_tokens_to_sample: Optional[int] = None,
stop_sequences: Optional[list] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonCohereConfig():
class AmazonCohereConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@@ -100,79 +134,110 @@ class AmazonCohereConfig():
- `temperature` (float) model temperature, - `temperature` (float) model temperature,
- `return_likelihood` (string) n/a - `return_likelihood` (string) n/a
""" """
max_tokens: Optional[int]=None
temperature: Optional[float]=None
return_likelihood: Optional[str]=None
def __init__(self, max_tokens: Optional[int] = None
max_tokens: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, return_likelihood: Optional[str] = None
return_likelihood: Optional[str]=None) -> None:
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAI21Config():
class AmazonAI21Config:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
Supported Params for the Amazon / AI21 models: Supported Params for the Amazon / AI21 models:
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `frequencyPenalty` (object): Placeholder for frequency penalty object. - `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object. - `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object. - `countPenalty` (object): Placeholder for count penalty object.
""" """
maxTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self, maxTokens: Optional[int] = None
maxTokens: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[float] = None
topP: Optional[float]=None, stopSequences: Optional[list] = None
stopSequences: Optional[list]=None, frequencePenalty: Optional[dict] = None
frequencePenalty: Optional[dict]=None, presencePenalty: Optional[dict] = None
presencePenalty: Optional[dict]=None, countPenalty: Optional[dict] = None
countPenalty: Optional[dict]=None) -> None:
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
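# A small sketch of the prompt shape these constants imply for Anthropic models
# on Bedrock: user turns follow HUMAN_PROMPT and the assistant turn is opened
# with AI_PROMPT (the message content here is hypothetical).
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
prompt = f"{HUMAN_PROMPT}Summarize the plot of Hamlet.{AI_PROMPT}"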
class AmazonLlamaConfig():
class AmazonLlamaConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1 Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
@@ -182,48 +247,72 @@ class AmazonLlamaConfig():
- `temperature` (float) temperature for model, - `temperature` (float) temperature for model,
- `top_p` (float) top p for model - `top_p` (float) top p for model
""" """
max_gen_len: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
def __init__(self, max_gen_len: Optional[int] = None
maxTokenCount: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[float] = None
topP: Optional[int]=None) -> None:
def __init__(
self,
maxTokenCount: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def init_bedrock_client( def init_bedrock_client(
region_name = None, region_name=None,
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] =None, aws_region_name: Optional[str] = None,
aws_bedrock_runtime_endpoint: Optional[str]=None, aws_bedrock_runtime_endpoint: Optional[str] = None,
): ):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None) standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IF 'os.environ/' is passed in ## CHECK IF 'os.environ/' is passed in
# Define the list of parameters to check # Define the list of parameters to check
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint] params_to_check = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
]
# Iterate over parameters and update if needed # Iterate over parameters and update if needed
for i, param in enumerate(params_to_check): for i, param in enumerate(params_to_check):
if param and param.startswith('os.environ/'): if param and param.startswith("os.environ/"):
params_to_check[i] = get_secret(param) params_to_check[i] = get_secret(param)
# Assign updated values back to parameters # Assign updated values back to parameters
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check (
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
) = params_to_check
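# A sketch of the "os.environ/" indirection resolved above: each auth param may
# be passed either as a literal value or as a reference to an env var that
# get_secret() resolves. The example values are hypothetical.
from litellm.utils import get_secret

example_params = [
    "os.environ/AWS_ACCESS_KEY_ID",      # resolved via get_secret(...)
    "os.environ/AWS_SECRET_ACCESS_KEY",  # resolved via get_secret(...)
    "us-west-2",                         # literal values pass through untouched
    None,
]
resolved = [
    get_secret(p) if isinstance(p, str) and p.startswith("os.environ/") else p
    for p in example_params
]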
if region_name: if region_name:
pass pass
elif aws_region_name: elif aws_region_name:
@@ -233,7 +322,10 @@ def init_bedrock_client(
elif standard_aws_region_name: elif standard_aws_region_name:
region_name = standard_aws_region_name region_name = standard_aws_region_name
else: else:
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401) raise BedrockError(
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
status_code=401,
)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@@ -242,9 +334,10 @@ def init_bedrock_client(
elif env_aws_bedrock_runtime_endpoint: elif env_aws_bedrock_runtime_endpoint:
endpoint_url = env_aws_bedrock_runtime_endpoint endpoint_url = env_aws_bedrock_runtime_endpoint
else: else:
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com' endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
import boto3 import boto3
if aws_access_key_id != None: if aws_access_key_id != None:
# uses auth params passed to completion # uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@@ -257,7 +350,7 @@ def init_bedrock_client(
endpoint_url=endpoint_url, endpoint_url=endpoint_url,
) )
else: else:
# aws_access_key_id is None, assume user is trying to auth using env variables # aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables # boto3 automatically reads env variables
client = boto3.client( client = boto3.client(
@@ -276,25 +369,23 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic") prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
else: else:
prompt = "" prompt = ""
for message in messages: for message in messages:
if "role" in message: if "role" in message:
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
return prompt return prompt
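# A sketch of a custom_prompt_dict entry the branch above looks up; the keys
# mirror the custom_prompt(...) call in this function, while the model name,
# role-dict shape and template strings are illustrative assumptions.
custom_prompt_dict = {
    "my-org.my-bedrock-model": {
        "roles": {
            "user": {"pre_message": "\n\nHuman: ", "post_message": ""},
            "assistant": {"pre_message": "\n\nAssistant: ", "post_message": ""},
        },
        "initial_prompt_value": "",
        "final_prompt_value": "\n\nAssistant: ",
    }
}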
@@ -309,17 +400,18 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name> # set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
custom_prompt_dict: dict, custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
@@ -327,7 +419,9 @@ def completion(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop( client = optional_params.pop(
@@ -343,67 +437,71 @@ def completion(
model = model model = model
provider = model.split(".")[0] provider = model.split(".")[0]
prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", False) stream = inference_params.pop("stream", False)
if provider == "anthropic": if provider == "anthropic":
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config() config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
data = json.dumps({ data = json.dumps({"prompt": prompt, **inference_params})
"prompt": prompt,
**inference_params
})
elif provider == "ai21": elif provider == "ai21":
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config() config = litellm.AmazonAI21Config.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
data = json.dumps({ data = json.dumps({"prompt": prompt, **inference_params})
"prompt": prompt,
**inference_params
})
elif provider == "cohere": elif provider == "cohere":
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config() config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) == True:
inference_params["stream"] = True # cohere requires stream = True in inference params inference_params[
data = json.dumps({ "stream"
"prompt": prompt, ] = True # cohere requires stream = True in inference params
**inference_params data = json.dumps({"prompt": prompt, **inference_params})
})
elif provider == "meta": elif provider == "meta":
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config() config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
data = json.dumps({ data = json.dumps({"prompt": prompt, **inference_params})
"prompt": prompt,
**inference_params
})
elif provider == "amazon": # amazon titan elif provider == "amazon": # amazon titan
## LOAD CONFIG ## LOAD CONFIG
config = litellm.AmazonTitanConfig.get_config() config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
data = json.dumps({ data = json.dumps(
"inputText": prompt, {
"textGenerationConfig": inference_params, "inputText": prompt,
}) "textGenerationConfig": inference_params,
}
)
## COMPLETION CALL ## COMPLETION CALL
accept = 'application/json' accept = "application/json"
contentType = 'application/json' contentType = "application/json"
if stream == True: if stream == True:
if provider == "ai21": if provider == "ai21":
## LOGGING ## LOGGING
@@ -418,17 +516,17 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
) )
response = client.invoke_model( response = client.invoke_model(
body=data, body=data, modelId=model, accept=accept, contentType=contentType
modelId=model,
accept=accept,
contentType=contentType
) )
response = response.get('body').read() response = response.get("body").read()
return response return response
else: else:
## LOGGING ## LOGGING
@@ -441,20 +539,20 @@ def completion(
) )
""" """
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
) )
response = client.invoke_model_with_response_stream( response = client.invoke_model_with_response_stream(
body=data, body=data, modelId=model, accept=accept, contentType=contentType
modelId=model,
accept=accept,
contentType=contentType
) )
response = response.get('body') response = response.get("body")
return response return response
try: try:
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_model( response = client.invoke_model(
@@ -465,20 +563,20 @@ def completion(
) )
""" """
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
) "complete_input_dict": data,
response = client.invoke_model( "request_str": request_str,
body=data, },
modelId=model,
accept=accept,
contentType=contentType
) )
except Exception as e: response = client.invoke_model(
body=data, modelId=model, accept=accept, contentType=contentType
)
except Exception as e:
raise BedrockError(status_code=500, message=str(e)) raise BedrockError(status_code=500, message=str(e))
response_body = json.loads(response.get('body').read()) response_body = json.loads(response.get("body").read())
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@@ -491,16 +589,16 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
outputText = "default" outputText = "default"
if provider == "ai21": if provider == "ai21":
outputText = response_body.get('completions')[0].get('data').get('text') outputText = response_body.get("completions")[0].get("data").get("text")
elif provider == "anthropic": elif provider == "anthropic":
outputText = response_body['completion'] outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"] model_response["finish_reason"] = response_body["stop_reason"]
elif provider == "cohere": elif provider == "cohere":
outputText = response_body["generations"][0]["text"] outputText = response_body["generations"][0]["text"]
elif provider == "meta": elif provider == "meta":
outputText = response_body["generation"] outputText = response_body["generation"]
else: # amazon titan else: # amazon titan
outputText = response_body.get('results')[0].get('outputText') outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {}) response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400: if response_metadata.get("HTTPStatusCode", 500) >= 400:
@@ -513,12 +611,13 @@ def completion(
if len(outputText) > 0: if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText model_response["choices"][0]["message"]["content"] = outputText
except: except:
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500)) raise BedrockError(
message=json.dumps(outputText),
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -528,41 +627,47 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens = prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except BedrockError as e: except BedrockError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if exception_mapping_worked: if exception_mapping_worked:
raise e raise e
else: else:
import traceback import traceback
raise BedrockError(status_code=500, message=traceback.format_exc()) raise BedrockError(status_code=500, message=traceback.format_exc())
def _embedding_func_single( def _embedding_func_single(
model: str, model: str,
input: str, input: str,
client: Any, client: Any,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
logging_obj=None, logging_obj=None,
): ):
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
## FORMAT EMBEDDING INPUT ## ## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0] provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("user", None) # make sure user is not passed in for bedrock call inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
if provider == "amazon": if provider == "amazon":
input = input.replace(os.linesep, " ") input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params} data = {"inputText": input, **inference_params}
# data = json.dumps(data) # data = json.dumps(data)
elif provider == "cohere": elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 inference_params["input_type"] = inference_params.get(
data = {"texts": [input], **inference_params} # type: ignore "input_type", "search_document"
body = json.dumps(data).encode("utf-8") ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8")
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_model( response = client.invoke_model(
@@ -570,12 +675,14 @@ def _embedding_func_single(
modelId={model}, modelId={model},
accept="*/*", accept="*/*",
contentType="application/json", contentType="application/json",
)""" # type: ignore )""" # type: ignore
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key="", # boto3 is used for init. api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model, additional_args={
"texts": input}, "request_str": request_str}, "complete_input_dict": {"model": model, "texts": input},
"request_str": request_str,
},
) )
try: try:
response = client.invoke_model( response = client.invoke_model(
@@ -587,11 +694,11 @@ def _embedding_func_single(
response_body = json.loads(response.get("body").read()) response_body = json.loads(response.get("body").read())
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key="", api_key="",
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body), original_response=json.dumps(response_body),
) )
if provider == "cohere": if provider == "cohere":
response = response_body.get("embeddings") response = response_body.get("embeddings")
# flatten list # flatten list
@@ -600,7 +707,10 @@ def _embedding_func_single(
elif provider == "amazon": elif provider == "amazon":
return response_body.get("embedding") return response_body.get("embedding")
except Exception as e: except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500) raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
def embedding( def embedding(
model: str, model: str,
@@ -616,7 +726,9 @@ def embedding(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@@ -624,11 +736,19 @@ def embedding(
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
) )
## Embedding Call ## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls embeddings = [
_embedding_func_single(
model,
i,
optional_params=optional_params,
client=client,
logging_obj=logging_obj,
)
for i in input
] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary ## Populate OpenAI compliant dictionary
embedding_response = [] embedding_response = []
@@ -647,13 +767,11 @@ def embedding(
input_str = "".join(input) input_str = "".join(input)
input_tokens+=len(encoding.encode(input_str)) input_tokens += len(encoding.encode(input_str))
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
completion_tokens=0,
total_tokens=input_tokens + 0
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
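For context, a minimal sketch of how the Bedrock completion and embedding paths above are typically reached through litellm; the model ids and credentials below are illustrative, and the provider-specific payload shape (anthropic, ai21, cohere, meta, amazon) is picked from the prefix of the model id as shown above.

import os
import litellm

# credentials can also be passed per call as aws_access_key_id / aws_secret_access_key / aws_region_name
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = "us-east-1"

response = litellm.completion(
    model="bedrock/anthropic.claude-v2",  # "anthropic." prefix -> {"prompt": ..., **inference_params} payload
    messages=[{"role": "user", "content": "When will BerriAI IPO?"}],
    max_tokens=100,
)

embeddings = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",  # "amazon." prefix -> {"inputText": ...} payload
    input=["good morning from litellm"],
)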


@@ -8,88 +8,106 @@ from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
import httpx import httpx
class CohereError(Exception): class CohereError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate") self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class CohereConfig():
class CohereConfig:
""" """
Reference: https://docs.cohere.com/reference/generate Reference: https://docs.cohere.com/reference/generate
The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters: The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters:
- `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5. - `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20. - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
- `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END. - `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75. - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
- `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc. - `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
- `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text. - `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
- `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text. - `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
- `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0. - `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
- `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0. - `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
- `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens. - `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
- `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared. - `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
- `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE. - `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233} - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
""" """
num_generations: Optional[int]=None
max_tokens: Optional[int]=None num_generations: Optional[int] = None
truncate: Optional[str]=None max_tokens: Optional[int] = None
temperature: Optional[int]=None truncate: Optional[str] = None
preset: Optional[str]=None temperature: Optional[int] = None
end_sequences: Optional[list]=None preset: Optional[str] = None
stop_sequences: Optional[list]=None end_sequences: Optional[list] = None
k: Optional[int]=None stop_sequences: Optional[list] = None
p: Optional[int]=None k: Optional[int] = None
frequency_penalty: Optional[int]=None p: Optional[int] = None
presence_penalty: Optional[int]=None frequency_penalty: Optional[int] = None
return_likelihoods: Optional[str]=None presence_penalty: Optional[int] = None
logit_bias: Optional[dict]=None return_likelihoods: Optional[str] = None
logit_bias: Optional[dict] = None
def __init__(self,
num_generations: Optional[int]=None, def __init__(
max_tokens: Optional[int]=None, self,
truncate: Optional[str]=None, num_generations: Optional[int] = None,
temperature: Optional[int]=None, max_tokens: Optional[int] = None,
preset: Optional[str]=None, truncate: Optional[str] = None,
end_sequences: Optional[list]=None, temperature: Optional[int] = None,
stop_sequences: Optional[list]=None, preset: Optional[str] = None,
k: Optional[int]=None, end_sequences: Optional[list] = None,
p: Optional[int]=None, stop_sequences: Optional[list] = None,
frequency_penalty: Optional[int]=None, k: Optional[int] = None,
presence_penalty: Optional[int]=None, p: Optional[int] = None,
return_likelihoods: Optional[str]=None, frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict]=None) -> None: presence_penalty: Optional[int] = None,
return_likelihoods: Optional[str] = None,
logit_bias: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
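For context, a short sketch of how the config class above is meant to be used: instantiating it pins class-level defaults, get_config() exposes them, and completion() merges them in unless the caller overrides a key. The model name and values are illustrative.

import litellm

litellm.CohereConfig(max_tokens=200, temperature=0.75)  # stored as class attributes

print(litellm.CohereConfig.get_config())
# -> {'max_tokens': 200, 'temperature': 0.75}

response = litellm.completion(
    model="command-nightly",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=50,  # per-call kwargs win over the configured default
)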
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
@@ -100,6 +118,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@@ -119,9 +138,11 @@ def completion(
prompt = " ".join(message["content"] for message in messages) prompt = " ".join(message["content"] for message in messages)
## Load Config ## Load Config
config=litellm.CohereConfig.get_config() config = litellm.CohereConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@@ -132,16 +153,23 @@ def completion(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, additional_args={
) "complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
## error handling for cohere calls ## error handling for cohere calls
if response.status_code!=200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
@@ -149,11 +177,11 @@ def completion(
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response.json() completion_response = response.json()
@@ -168,18 +196,22 @@ def completion(
for idx, item in enumerate(completion_response["generations"]): for idx, item in enumerate(completion_response["generations"]):
if len(item["text"]) > 0: if len(item["text"]) > 0:
message_obj = Message(content=item["text"]) message_obj = Message(content=item["text"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -189,11 +221,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
@@ -206,11 +239,7 @@ def embedding(
headers = validate_environment(api_key) headers = validate_environment(api_key)
embed_url = "https://api.cohere.ai/v1/embed" embed_url = "https://api.cohere.ai/v1/embed"
model = model model = model
data = { data = {"model": model, "texts": input, **optional_params}
"model": model,
"texts": input,
**optional_params
}
if "3" in model and "input_type" not in data: if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document" # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
@@ -218,21 +247,19 @@ def embedding(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
embed_url, headers=headers, data=json.dumps(data)
) )
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response, original_response=response,
) )
""" """
response response
{ {
@@ -244,30 +271,23 @@ def embedding(
'usage' 'usage'
} }
""" """
if response.status_code!=200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
embeddings = response.json()['embeddings'] embeddings = response.json()["embeddings"]
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = output_data model_response["data"] = output_data
model_response["model"] = model model_response["model"] = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = { model_response["usage"] = {
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
return model_response return model_response
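A minimal sketch of the embedding path above, assuming the usual COHERE_API_KEY environment variable; the model name is illustrative. Note the input_type default applied to v3 models earlier in the function.

import os
import litellm

os.environ["COHERE_API_KEY"] = ""

response = litellm.embedding(
    model="embed-english-v3.0",  # "3" in the name -> input_type defaults to "search_document"
    input=["good morning from litellm"],
)
print(len(response["data"]))  # one embedding object per input string
print(response["usage"])      # {"prompt_tokens": ..., "total_tokens": ...}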


@@ -1,20 +1,24 @@
import time, json, httpx, asyncio import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
""" """
Async implementation of custom http transport Async implementation of custom http transport
""" """
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[ if "images/generations" in request.url.path and request.url.params[
"api-version" "api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview", "2023-06-01-preview",
"2023-07-01-preview", "2023-07-01-preview",
"2023-08-01-preview", "2023-08-01-preview",
"2023-09-01-preview", "2023-09-01-preview",
"2023-10-01-preview", "2023-10-01-preview",
]: ]:
request.url = request.url.copy_with(path="/openai/images/generations:submit") request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = await super().handle_async_request(request) response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"] operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url) request.url = httpx.URL(operation_location_url)
@@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
start_time = time.time() start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs: if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response( return httpx.Response(
status_code=400, status_code=400,
headers=response.headers, headers=response.headers,
@@ -56,26 +65,30 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
) )
return await super().handle_async_request(request) return await super().handle_async_request(request)
class CustomHTTPTransport(httpx.HTTPTransport): class CustomHTTPTransport(httpx.HTTPTransport):
""" """
This class was written as a workaround to support dall-e-2 on openai > v1.x This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692 Refer to this issue for more: https://github.com/openai/openai-python/issues/692
""" """
def handle_request( def handle_request(
self, self,
request: httpx.Request, request: httpx.Request,
) -> httpx.Response: ) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[ if "images/generations" in request.url.path and request.url.params[
"api-version" "api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict ] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview", "2023-06-01-preview",
"2023-07-01-preview", "2023-07-01-preview",
"2023-08-01-preview", "2023-08-01-preview",
"2023-09-01-preview", "2023-09-01-preview",
"2023-10-01-preview", "2023-10-01-preview",
]: ]:
request.url = request.url.copy_with(path="/openai/images/generations:submit") request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = super().handle_request(request) response = super().handle_request(request)
operation_location_url = response.headers["operation-location"] operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url) request.url = httpx.URL(operation_location_url)
@@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport):
start_time = time.time() start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs: if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response( return httpx.Response(
status_code=400, status_code=400,
headers=response.headers, headers=response.headers,
@@ -115,4 +133,4 @@ class CustomHTTPTransport(httpx.HTTPTransport):
content=json.dumps(result).encode("utf-8"), content=json.dumps(result).encode("utf-8"),
request=request, request=request,
) )
return super().handle_request(request) return super().handle_request(request)
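For context, these transports are meant to be mounted on the httpx client that the openai SDK uses, so dall-e-2 image calls on the older Azure api-versions listed above go through the polling flow implemented here. A rough sketch of that wiring; the import path, endpoint, api-version and deployment name are assumptions, not taken from this diff.

import httpx
from openai import AzureOpenAI

from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport  # module path assumed

client = AzureOpenAI(
    api_key="...",
    api_version="2023-09-01-preview",  # one of the api-versions handled above
    azure_endpoint="https://my-resource.openai.azure.com",
    http_client=httpx.Client(transport=CustomHTTPTransport()),
)
image = client.images.generate(model="dall-e-2", prompt="a red bicycle", n=1)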


@@ -8,17 +8,22 @@ import litellm
import sys, httpx import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class GeminiError(Exception): class GeminiError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat") self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class GeminiConfig():
class GeminiConfig:
""" """
Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
@@ -37,33 +42,44 @@ class GeminiConfig():
- `top_k` (int): Optional. The maximum number of tokens to consider when sampling. - `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
""" """
candidate_count: Optional[int]=None candidate_count: Optional[int] = None
stop_sequences: Optional[list]=None stop_sequences: Optional[list] = None
max_output_tokens: Optional[int]=None max_output_tokens: Optional[int] = None
temperature: Optional[float]=None temperature: Optional[float] = None
top_p: Optional[float]=None top_p: Optional[float] = None
top_k: Optional[int]=None top_k: Optional[int] = None
def __init__(self, def __init__(
candidate_count: Optional[int]=None, self,
stop_sequences: Optional[list]=None, candidate_count: Optional[int] = None,
max_output_tokens: Optional[int]=None, stop_sequences: Optional[list] = None,
temperature: Optional[float]=None, max_output_tokens: Optional[int] = None,
top_p: Optional[float]=None, temperature: Optional[float] = None,
top_k: Optional[int]=None) -> None: top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
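As with the other provider configs, a short usage sketch, assuming the GEMINI_API_KEY environment variable and an illustrative model name; the defaults set here are merged into inference_params by the completion() function below.

import os
import litellm

os.environ["GEMINI_API_KEY"] = ""

litellm.GeminiConfig(max_output_tokens=256, temperature=0.2)  # class-level defaults

response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Write a haiku about tokenizers"}],
)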
def completion( def completion(
@@ -83,42 +99,50 @@ def completion(
try: try:
import google.generativeai as genai import google.generativeai as genai
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini") prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py inference_params.pop(
config = litellm.GeminiConfig.get_config() "stream", None
for k, v in config.items(): ) # palm does not support streaming, so we handle this by fake streaming in main.py
if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}}, additional_args={"complete_input_dict": {"inference_params": inference_params}},
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
_model = genai.GenerativeModel(f'models/{model}') _model = genai.GenerativeModel(f"models/{model}")
response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params)) response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
except Exception as e: except Exception as e:
raise GeminiError( raise GeminiError(
message=str(e), message=str(e),
@@ -127,11 +151,11 @@ def completion(
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=response, original_response=response,
additional_args={"complete_input_dict": {}}, additional_args={"complete_input_dict": {}},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response completion_response = response
@@ -142,31 +166,34 @@ def completion(
message_obj = Message(content=item.content.parts[0].text) message_obj = Message(content=item.content.parts[0].text)
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise GeminiError(message=traceback.format_exc(), status_code=response.status_code) raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
try: )
try:
completion_response = model_response["choices"][0]["message"].get("content") completion_response = model_response["choices"][0]["message"].get("content")
except: except:
raise GeminiError(status_code=400, message=f"No response received. Original response - {response}") raise GeminiError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_str = "" prompt_str = ""
for m in messages: for m in messages:
if isinstance(m["content"], str): if isinstance(m["content"], str):
prompt_str += m["content"] prompt_str += m["content"]
elif isinstance(m["content"], list): elif isinstance(m["content"], list):
for content in m["content"]: for content in m["content"]:
if content["type"] == "text": if content["type"] == "text":
prompt_str += content["text"] prompt_str += content["text"]
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt_str))
encoding.encode(prompt_str)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -174,13 +201,14 @@ def completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "gemini/" + model model_response["model"] = "gemini/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Optional from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class HuggingfaceError(Exception): class HuggingfaceError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
if request is not None: if request is not None:
self.request = request self.request = request
else: else:
self.request = httpx.Request(method="POST", url="https://api-inference.huggingface.co/models") self.request = httpx.Request(
method="POST", url="https://api-inference.huggingface.co/models"
)
if response is not None: if response is not None:
self.response = response self.response = response
else: else:
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class HuggingfaceConfig():
class HuggingfaceConfig:
""" """
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
""" """
best_of: Optional[int] = None best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of details: Optional[bool] = True # enables returning logprobs + best of
max_new_tokens: Optional[int] = None max_new_tokens: Optional[int] = None
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = False # by default don't return the input as part of the output return_full_text: Optional[
bool
] = False # by default don't return the input as part of the output
seed: Optional[int] = None seed: Optional[int] = None
temperature: Optional[float] = None temperature: Optional[float] = None
top_k: Optional[int] = None top_k: Optional[int] = None
@@ -46,50 +61,66 @@ class HuggingfaceConfig():
typical_p: Optional[float] = None typical_p: Optional[float] = None
watermark: Optional[bool] = None watermark: Optional[bool] = None
def __init__(self, def __init__(
best_of: Optional[int] = None, self,
decoder_input_details: Optional[bool] = None, best_of: Optional[int] = None,
details: Optional[bool] = None, decoder_input_details: Optional[bool] = None,
max_new_tokens: Optional[int] = None, details: Optional[bool] = None,
repetition_penalty: Optional[float] = None, max_new_tokens: Optional[int] = None,
return_full_text: Optional[bool] = None, repetition_penalty: Optional[float] = None,
seed: Optional[int] = None, return_full_text: Optional[bool] = None,
temperature: Optional[float] = None, seed: Optional[int] = None,
top_k: Optional[int] = None, temperature: Optional[float] = None,
top_n_tokens: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[int] = None, top_n_tokens: Optional[int] = None,
truncate: Optional[int] = None, top_p: Optional[int] = None,
typical_p: Optional[float] = None, truncate: Optional[int] = None,
watermark: Optional[bool] = None typical_p: Optional[float] = None,
) -> None: watermark: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
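A short usage sketch for the config above, assuming the HUGGINGFACE_API_KEY environment variable; the model and endpoint URL are illustrative. api_base is only needed for a dedicated Inference Endpoint, otherwise the serverless api-inference URL is built from the model name further down.

import os
import litellm

os.environ["HUGGINGFACE_API_KEY"] = ""

litellm.HuggingfaceConfig(max_new_tokens=200, temperature=0.7)  # merged into optional_params below

response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://my-endpoint.endpoints.huggingface.cloud",  # optional, illustrative
)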
def output_parser(generated_text: str):
def output_parser(generated_text: str):
""" """
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763 Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
""" """
chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"] chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens: for token in chat_template_tokens:
if generated_text.strip().startswith(token): if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1) generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token): if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text return generated_text
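For reference, the parser above only strips ChatML-style wrapper tokens from the very start and end of a generation; interior occurrences are left alone:

raw = "<|assistant|> Sure, here is a summary.</s>"
print(output_parser(raw))        # -> " Sure, here is a summary."
print(output_parser("a <s> b"))  # unchanged: the token is neither a prefix nor a suffix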
tgi_models_cache = None tgi_models_cache = None
conv_models_cache = None conv_models_cache = None
def read_tgi_conv_models(): def read_tgi_conv_models():
try: try:
global tgi_models_cache, conv_models_cache global tgi_models_cache, conv_models_cache
@@ -101,30 +132,38 @@ def read_tgi_conv_models():
tgi_models = set() tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__)) script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the file path relative to the script's directory # Construct the file path relative to the script's directory
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt") file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, 'r') as file: with open(file_path, "r") as file:
for line in file: for line in file:
tgi_models.add(line.strip()) tgi_models.add(line.strip())
# Cache the set for future use # Cache the set for future use
tgi_models_cache = tgi_models tgi_models_cache = tgi_models
# If not, read the file and populate the cache # If not, read the file and populate the cache
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt") file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set() conv_models = set()
with open(file_path, 'r') as file: with open(file_path, "r") as file:
for line in file: for line in file:
conv_models.add(line.strip()) conv_models.add(line.strip())
# Cache the set for future use # Cache the set for future use
conv_models_cache = conv_models conv_models_cache = conv_models
return tgi_models, conv_models return tgi_models, conv_models
except: except:
return set(), set() return set(), set()
def get_hf_task_for_model(model): def get_hf_task_for_model(model):
# read text file, cast it to set # read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt" # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
tgi_models, conversational_models = read_tgi_conv_models() tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models: if model in tgi_models:
@@ -134,9 +173,10 @@ def get_hf_task_for_model(model):
elif "roneneldan/TinyStories" in model: elif "roneneldan/TinyStories" in model:
return None return None
else: else:
return "text-generation-inference" # default to tgi return "text-generation-inference" # default to tgi
class Huggingface(BaseLLM):
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None _aclient_session: Optional[httpx.AsyncClient] = None
@@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
"content-type": "application/json", "content-type": "application/json",
} }
if api_key and headers is None: if api_key and headers is None:
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = default_headers headers = default_headers
elif headers: elif headers:
headers=headers headers = headers
else: else:
headers = default_headers headers = default_headers
return headers return headers
def convert_to_model_response_object(self, def convert_to_model_response_object(
completion_response, self,
model_response, completion_response,
task, model_response,
optional_params, task,
encoding, optional_params,
input_text, encoding,
model): input_text,
if task == "conversational": model,
if len(completion_response["generated_text"]) > 0: # type: ignore ):
if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
"content" "content"
] = completion_response["generated_text"] # type: ignore ] = completion_response[
elif task == "text-generation-inference": "generated_text"
if (not isinstance(completion_response, list) ] # type: ignore
elif task == "text-generation-inference":
if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict) or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]): or "generated_text" not in completion_response[0]
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}") ):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = output_parser(
"content" completion_response[0]["generated_text"]
] = output_parser(completion_response[0]["generated_text"]) )
## GETTING LOGPROBS + FINISH REASON ## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: if (
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] "details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None: if token["logprob"] != None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]: if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = [] choices_list = []
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]): for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0 sum_logprob = 0
for token in item["tokens"]: for token in item["tokens"]:
if token["logprob"] != None: if token["logprob"] != None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0: if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob) message_obj = Message(
else: content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response["choices"].extend(choices_list)
else: else:
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = output_parser(
"content" completion_response[0]["generated_text"]
] = output_parser(completion_response[0]["generated_text"]) )
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = 0 prompt_tokens = 0
try: try:
@@ -221,12 +289,14 @@ class Huggingface(BaseLLM):
completion_tokens = 0 completion_tokens = 0
try: try:
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here ) ##[TODO] use the llama2 tokenizer here
except: except:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
pass pass
else: else:
completion_tokens = 0 completion_tokens = 0
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
@@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
model_response._hidden_params["original_response"] = completion_response model_response._hidden_params["original_response"] = completion_response
return model_response return model_response
def completion(self, def completion(
self,
model: str, model: str,
messages: list, messages: list,
api_base: Optional[str], api_base: Optional[str],
@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
completion_url = f"https://api-inference.huggingface.co/models/{model}" completion_url = f"https://api-inference.huggingface.co/models/{model}"
## Load Config ## Load Config
config=litellm.HuggingfaceConfig.get_config() config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
### MAP INPUT PARAMS ### MAP INPUT PARAMS
@ -298,11 +371,11 @@ class Huggingface(BaseLLM):
generated_responses.append(message["content"]) generated_responses.append(message["content"])
data = { data = {
"inputs": { "inputs": {
"text": text, "text": text,
"past_user_inputs": past_user_inputs, "past_user_inputs": past_user_inputs,
"generated_responses": generated_responses "generated_responses": generated_responses,
}, },
"parameters": inference_params "parameters": inference_params,
} }
input_text = "".join(message["content"] for message in messages) input_text = "".join(message["content"] for message in messages)
elif task == "text-generation-inference": elif task == "text-generation-inference":
@ -311,29 +384,39 @@ class Huggingface(BaseLLM):
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None), role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get(
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), "initial_prompt_value", ""
messages=messages ),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False, "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
input_text = prompt input_text = prompt
else: else:
# Non TGI and Conversational llms # Non TGI and Conversational llms
# We need this branch, it removes 'details' and 'return_full_text' from params # We need this branch, it removes 'details' and 'return_full_text' from params
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}), role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get(
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), "initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""), bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""), eos_token=model_prompt_details.get("eos_token", ""),
messages=messages, messages=messages,
@ -346,52 +429,68 @@ class Huggingface(BaseLLM):
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": inference_params, "parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False, "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
input_text = prompt input_text = prompt
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input_text, input=input_text,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion}, additional_args={
) "complete_input_dict": data,
"task": task,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL ## COMPLETION CALL
if acompletion is True: if acompletion is True:
### ASYNC STREAMING ### ASYNC STREAMING
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model) # type: ignore
else: else:
### ASYNC COMPLETION ### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params) # type: ignore
### SYNC STREAMING ### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post( response = requests.post(
completion_url, completion_url,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] stream=optional_params["stream"],
) )
return response.iter_lines() return response.iter_lines()
### SYNC COMPLETION ### SYNC COMPLETION
else: else:
response = requests.post( response = requests.post(
completion_url, completion_url, headers=headers, data=json.dumps(data)
headers=headers,
data=json.dumps(data)
) )
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten) ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False is_streamed = False
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream": if (
response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True is_streamed = True
# iterate over the complete streamed response, and return the final answer # iterate over the complete streamed response, and return the final answer
if is_streamed: if is_streamed:
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj) streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = "" content = ""
for chunk in streamed_response: for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"] content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}] completion_response: List[Dict[str, Any]] = [
{"generated_text": content}
]
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input_text, input=input_text,
@ -399,7 +498,7 @@ class Huggingface(BaseLLM):
original_response=completion_response, original_response=completion_response,
additional_args={"complete_input_dict": data, "task": task}, additional_args={"complete_input_dict": data, "task": task},
) )
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input_text, input=input_text,
@ -410,15 +509,20 @@ class Huggingface(BaseLLM):
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
if isinstance(completion_response, dict): if isinstance(completion_response, dict):
completion_response = [completion_response] completion_response = [completion_response]
except: except:
import traceback import traceback
raise HuggingfaceError( raise HuggingfaceError(
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
status_code=response.status_code,
) )
print_verbose(f"response: {completion_response}") print_verbose(f"response: {completion_response}")
if isinstance(completion_response, dict) and "error" in completion_response: if (
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}") print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"response.status_code: {response.status_code}") print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError( raise HuggingfaceError(
@ -432,75 +536,98 @@ class Huggingface(BaseLLM):
optional_params=optional_params, optional_params=optional_params,
encoding=encoding, encoding=encoding,
input_text=input_text, input_text=input_text,
model=model model=model,
) )
except HuggingfaceError as e: except HuggingfaceError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
if exception_mapping_worked: if exception_mapping_worked:
raise e raise e
else: else:
import traceback import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc()) raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(self, async def acompletion(
api_base: str, self,
data: dict, api_base: str,
headers: dict, data: dict,
model_response: ModelResponse, headers: dict,
task: str, model_response: ModelResponse,
encoding: Any, task: str,
input_text: str, encoding: Any,
model: str, input_text: str,
optional_params: dict): model: str,
response = None optional_params: dict,
try: ):
response = None
try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers, timeout=None) response = await client.post(
url=api_base, json=data, headers=headers, timeout=None
)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response) raise HuggingfaceError(
status_code=response.status_code,
message=response.text,
request=response.request,
response=response,
)
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json, return self.convert_to_model_response_object(
model_response=model_response, completion_response=response_json,
task=task, model_response=model_response,
encoding=encoding, task=task,
input_text=input_text, encoding=encoding,
model=model, input_text=input_text,
optional_params=optional_params) model=model,
except Exception as e: optional_params=optional_params,
if isinstance(e,httpx.TimeoutException): )
except Exception as e:
if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error") raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif response is not None and hasattr(response, "text"): elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}") raise HuggingfaceError(
else: status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}") raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(self, async def async_streaming(
logging_obj, self,
api_base: str, logging_obj,
data: dict, api_base: str,
headers: dict, data: dict,
model_response: ModelResponse, headers: dict,
model: str): model_response: ModelResponse,
model: str,
):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = client.stream( response = client.stream(
"POST", "POST", url=f"{api_base}", json=data, headers=headers
url=f"{api_base}", )
json=data, async with response as r:
headers=headers
)
async with response as r:
if r.status_code != 200: if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming") raise HuggingfaceError(
status_code=r.status_code,
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj) message="An error occurred while streaming",
)
streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
def embedding(self, def embedding(
self,
model: str, model: str,
input: list, input: list,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -523,65 +650,70 @@ class Huggingface(BaseLLM):
embed_url = os.getenv("HUGGINGFACE_API_BASE", "") embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else: else:
embed_url = f"https://api-inference.huggingface.co/models/{model}" embed_url = f"https://api-inference.huggingface.co/models/{model}"
if "sentence-transformers" in model: if "sentence-transformers" in model:
if len(input) == 0: if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences") raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = { data = {
"inputs": { "inputs": {
"source_sentence": input[0], "source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ] "sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
} }
} }
else: else:
data = { data = {"inputs": input} # type: ignore
"inputs": input # type: ignore
}
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url}, additional_args={
) "complete_input_dict": data,
## COMPLETION CALL "headers": headers,
response = requests.post( "api_base": embed_url,
embed_url, headers=headers, data=json.dumps(data) },
) )
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response, original_response=response,
) )
embeddings = response.json() embeddings = response.json()
if "error" in embeddings: if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error']) raise HuggingfaceError(status_code=500, message=embeddings["error"])
output_data = [] output_data = []
if "similarities" in embeddings: if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]: for idx, embedding in embeddings["similarities"]:
output_data.append( output_data.append(
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
else: else:
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float): if isinstance(embedding, float):
output_data.append( output_data.append(
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
elif isinstance(embedding, list) and isinstance(embedding[0], float): elif isinstance(embedding, list) and isinstance(embedding[0], float):
@ -589,15 +721,17 @@ class Huggingface(BaseLLM):
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
else: else:
output_data.append( output_data.append(
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding[0][0] # flatten list returned from hf "embedding": embedding[0][
0
], # flatten list returned from hf
} }
) )
model_response["object"] = "list" model_response["object"] = "list"
@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
model_response["model"] = model model_response["model"] = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = { model_response["usage"] = {
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
return model_response return model_response
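
For reference, a minimal sketch of how the Huggingface path above is typically exercised through litellm.completion. The model id and token below are illustrative placeholders (not taken from this commit), and the endpoint is assumed to be a standard Inference API / TGI deployment:

import os
import litellm

os.environ["HUGGINGFACE_API_KEY"] = ""  # illustrative; supply a real token

response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # illustrative model id
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)

# fields populated by the code above
print(response["choices"][0]["message"]["content"])
print(response["choices"][0]["finish_reason"])
print(response.usage)  # prompt/completion token counts estimated with the encoding
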
View file
@ -7,6 +7,7 @@ from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
class MaritalkError(Exception): class MaritalkError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,24 +16,26 @@ class MaritalkError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class MaritTalkConfig():
class MaritTalkConfig:
""" """
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters: The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1. - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
- `model` (string): The model used for conversation. Default is 'maritalk'. - `model` (string): The model used for conversation. Default is 'maritalk'.
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True. - `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7. - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95. - `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1. - `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped. - `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
""" """
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
model: Optional[str] = None model: Optional[str] = None
do_sample: Optional[bool] = None do_sample: Optional[bool] = None
@ -41,27 +44,40 @@ class MaritTalkConfig():
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None stopping_tokens: Optional[List[str]] = None
def __init__(self, def __init__(
max_tokens: Optional[int]=None, self,
model: Optional[str] = None, max_tokens: Optional[int] = None,
do_sample: Optional[bool] = None, model: Optional[str] = None,
temperature: Optional[float] = None, do_sample: Optional[bool] = None,
top_p: Optional[float] = None, temperature: Optional[float] = None,
repetition_penalty: Optional[float] = None, top_p: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None) -> None: repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"accept": "application/json", "accept": "application/json",
@ -71,6 +87,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Key {api_key}" headers["Authorization"] = f"Key {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -89,9 +106,11 @@ def completion(
model = model model = model
## Load Config ## Load Config
config=litellm.MaritTalkConfig.get_config() config = litellm.MaritTalkConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -101,24 +120,27 @@ def completion(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=messages, input=messages,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=messages, input=messages,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response.json() completion_response = response.json()
@ -130,15 +152,17 @@ def completion(
else: else:
try: try:
if len(completion_response["answer"]) > 0: if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["answer"] model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
except Exception as e: except Exception as e:
raise MaritalkError(message=response.text, status_code=response.status_code) raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages) prompt = "".join(m["content"] for m in messages)
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -148,11 +172,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
@ -161,4 +186,4 @@ def embedding(
model_response=None, model_response=None,
encoding=None, encoding=None,
): ):
pass pass
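
As an aside, a small sketch of the config-merge behaviour used above: provider defaults set on MaritTalkConfig are copied into optional_params only when the caller has not passed them. The values below are illustrative:

import litellm

litellm.MaritTalkConfig(max_tokens=100, temperature=0.4)  # stored as class attributes
print(litellm.MaritTalkConfig.get_config())  # -> {'max_tokens': 100, 'temperature': 0.4}

# a later completion(..., temperature=0.9) keeps 0.9, because the loop above only
# copies config keys that are missing from optional_params
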
View file
@ -7,6 +7,7 @@ from typing import Callable, Optional
import litellm import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
class NLPCloudError(Exception): class NLPCloudError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,7 +16,8 @@ class NLPCloudError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class NLPCloudConfig():
class NLPCloudConfig:
""" """
Reference: https://docs.nlpcloud.com/#generation Reference: https://docs.nlpcloud.com/#generation
@ -43,45 +45,57 @@ class NLPCloudConfig():
- `num_return_sequences` (int): Optional. The number of independently computed returned sequences. - `num_return_sequences` (int): Optional. The number of independently computed returned sequences.
""" """
max_length: Optional[int]=None
length_no_input: Optional[bool]=None
end_sequence: Optional[str]=None
remove_end_sequence: Optional[bool]=None
remove_input: Optional[bool]=None
bad_words: Optional[list]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
num_beams: Optional[int]=None
num_return_sequences: Optional[int]=None
max_length: Optional[int] = None
length_no_input: Optional[bool] = None
end_sequence: Optional[str] = None
remove_end_sequence: Optional[bool] = None
remove_input: Optional[bool] = None
bad_words: Optional[list] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
num_beams: Optional[int] = None
num_return_sequences: Optional[int] = None
def __init__(self, def __init__(
max_length: Optional[int]=None, self,
length_no_input: Optional[bool]=None, max_length: Optional[int] = None,
end_sequence: Optional[str]=None, length_no_input: Optional[bool] = None,
remove_end_sequence: Optional[bool]=None, end_sequence: Optional[str] = None,
remove_input: Optional[bool]=None, remove_end_sequence: Optional[bool] = None,
bad_words: Optional[list]=None, remove_input: Optional[bool] = None,
temperature: Optional[float]=None, bad_words: Optional[list] = None,
top_p: Optional[float]=None, temperature: Optional[float] = None,
top_k: Optional[int]=None, top_p: Optional[float] = None,
repetition_penalty: Optional[float]=None, top_k: Optional[int] = None,
num_beams: Optional[int]=None, repetition_penalty: Optional[float] = None,
num_return_sequences: Optional[int]=None) -> None: num_beams: Optional[int] = None,
num_return_sequences: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@ -93,6 +107,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}" headers["Authorization"] = f"Token {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -110,9 +125,11 @@ def completion(
headers = validate_environment(api_key) headers = validate_environment(api_key)
## Load Config ## Load Config
config = litellm.NLPCloudConfig.get_config() config = litellm.NLPCloudConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
completion_url_fragment_1 = api_base completion_url_fragment_1 = api_base
@ -129,24 +146,31 @@ def completion(
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=text, input=text,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, additional_args={
) "complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return clean_and_iterate_chunks(response) return clean_and_iterate_chunks(response)
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=text, input=text,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
@ -161,11 +185,16 @@ def completion(
else: else:
try: try:
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["generated_text"] model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
except: except:
raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code) raise NLPCloudError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = completion_response["nb_input_tokens"] prompt_tokens = completion_response["nb_input_tokens"]
completion_tokens = completion_response["nb_generated_tokens"] completion_tokens = completion_response["nb_generated_tokens"]
@ -174,7 +203,7 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
@ -187,25 +216,27 @@ def completion(
# # Perform further processing based on your needs # # Perform further processing based on your needs
# return cleaned_chunk # return cleaned_chunk
# for line in response.iter_lines(): # for line in response.iter_lines():
# if line: # if line:
# yield process_chunk(line) # yield process_chunk(line)
def clean_and_iterate_chunks(response): def clean_and_iterate_chunks(response):
buffer = b'' buffer = b""
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):
if not chunk: if not chunk:
break break
buffer += chunk buffer += chunk
while b'\x00' in buffer: while b"\x00" in buffer:
buffer = buffer.replace(b'\x00', b'') buffer = buffer.replace(b"\x00", b"")
yield buffer.decode('utf-8') yield buffer.decode("utf-8")
buffer = b'' buffer = b""
# No more data expected, yield any remaining data in the buffer # No more data expected, yield any remaining data in the buffer
if buffer: if buffer:
yield buffer.decode('utf-8') yield buffer.decode("utf-8")
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
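
For illustration, a self-contained sketch of what clean_and_iterate_chunks does with the NUL-padded chunks NLP Cloud streams back. The fake response object below stands in for a real requests response and is not part of this commit:

class FakeStreamingResponse:
    # mimics requests.Response.iter_content for a NUL-delimited stream
    def iter_content(self, chunk_size=1024):
        yield b"Hello\x00"
        yield b", world\x00\x00"

def strip_nul_and_decode(response):
    # drop the NUL padding, then decode each chunk to text
    for chunk in response.iter_content(chunk_size=1024):
        yield chunk.replace(b"\x00", b"").decode("utf-8")

print("".join(strip_nul_and_decode(FakeStreamingResponse())))  # -> "Hello, world"
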
View file
@ -2,10 +2,11 @@ import requests, types, time
import json, uuid import json, uuid
import traceback import traceback
from typing import Optional from typing import Optional
import litellm import litellm
import httpx, aiohttp, asyncio import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception): class OllamaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -16,14 +17,15 @@ class OllamaError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class OllamaConfig():
class OllamaConfig:
""" """
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters: The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:
- `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0 - `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0
- `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1 - `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1
- `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0 - `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0
@ -56,102 +58,134 @@ class OllamaConfig():
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile) - `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
""" """
mirostat: Optional[int]=None
mirostat_eta: Optional[float]=None
mirostat_tau: Optional[float]=None
num_ctx: Optional[int]=None
num_gqa: Optional[int]=None
num_thread: Optional[int]=None
repeat_last_n: Optional[int]=None
repeat_penalty: Optional[float]=None
temperature: Optional[float]=None
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float]=None
num_predict: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
system: Optional[str]=None
template: Optional[str]=None
def __init__(self, mirostat: Optional[int] = None
mirostat: Optional[int]=None, mirostat_eta: Optional[float] = None
mirostat_eta: Optional[float]=None, mirostat_tau: Optional[float] = None
mirostat_tau: Optional[float]=None, num_ctx: Optional[int] = None
num_ctx: Optional[int]=None, num_gqa: Optional[int] = None
num_gqa: Optional[int]=None, num_thread: Optional[int] = None
num_thread: Optional[int]=None, repeat_last_n: Optional[int] = None
repeat_last_n: Optional[int]=None, repeat_penalty: Optional[float] = None
repeat_penalty: Optional[float]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, stop: Optional[
stop: Optional[list]=None, list
tfs_z: Optional[float]=None, ] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
num_predict: Optional[int]=None, tfs_z: Optional[float] = None
top_k: Optional[int]=None, num_predict: Optional[int] = None
top_p: Optional[float]=None, top_k: Optional[int] = None
system: Optional[str]=None, top_p: Optional[float] = None
template: Optional[str]=None) -> None: system: Optional[str] = None
template: Optional[str] = None
def __init__(
self,
mirostat: Optional[int] = None,
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
num_ctx: Optional[int] = None,
num_gqa: Optional[int] = None,
num_thread: Optional[int] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
api_base="http://localhost:11434", api_base="http://localhost:11434",
model="llama2", model="llama2",
prompt="Why is the sky blue?", prompt="Why is the sky blue?",
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
model_response=None, model_response=None,
encoding=None encoding=None,
): ):
if api_base.endswith("/api/generate"): if api_base.endswith("/api/generate"):
url = api_base url = api_base
else: else:
url = f"{api_base}/api/generate" url = f"{api_base}/api/generate"
## Load Config ## Load Config
config=litellm.OllamaConfig.get_config() config = litellm.OllamaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False) optional_params["stream"] = optional_params.get("stream", False)
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=None, input=None,
api_key=None, api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,}, additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
) )
if acompletion is True: if acompletion is True:
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) == True:
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else: else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response return response
elif optional_params.get("stream", False) == True: elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post( response = requests.post(
url=f"{url}", url=f"{url}",
json=data, json=data,
) )
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(status_code=response.status_code, message=response.text)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -168,52 +202,76 @@ def get_ollama_response(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json": if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response return model_response
def ollama_completion_stream(url, data, logging_obj): def ollama_completion_stream(url, data, logging_obj):
with httpx.stream( with httpx.stream(
url=url, url=url, json=data, method="POST", timeout=litellm.request_timeout
json=data, ) as response:
method="POST", try:
timeout=litellm.request_timeout
) as response:
try:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(
status_code=response.status_code, message=response.text
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) )
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
raise e raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try: try:
client = httpx.AsyncClient() client = httpx.AsyncClient()
async with client.stream( async with client.stream(
url=f"{url}", url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
json=data, ) as response:
method="POST", if response.status_code != 200:
timeout=litellm.request_timeout raise OllamaError(
) as response: status_code=response.status_code, message=response.text
if response.status_code != 200: )
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) completion_stream=response.aiter_lines(),
async for transformed_chunk in streamwrapper: model=data["model"],
yield transformed_chunk custom_llm_provider="ollama",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj): async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
data["stream"] = False data["stream"] = False
try: try:
@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()
raise OllamaError(status_code=resp.status, message=text) raise OllamaError(status_code=resp.status, message=text)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=data['prompt'], input=data["prompt"],
api_key="", api_key="",
original_response=resp.text, original_response=resp.text,
additional_args={ additional_args={
@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": response_json["response"],
"name": "",
},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json[
"response"
]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model'] model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json["prompt_eval_count"] # type: ignore prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise e raise e
async def ollama_aembeddings(api_base="http://localhost:11434", async def ollama_aembeddings(
model="llama2", api_base="http://localhost:11434",
prompt="Why is the sky blue?", model="llama2",
optional_params=None, prompt="Why is the sky blue?",
logging_obj=None, optional_params=None,
model_response=None, logging_obj=None,
encoding=None): model_response=None,
encoding=None,
):
if api_base.endswith("/api/embeddings"): if api_base.endswith("/api/embeddings"):
url = api_base url = api_base
else: else:
url = f"{api_base}/api/embeddings" url = f"{api_base}/api/embeddings"
## Load Config ## Load Config
config=litellm.OllamaConfig.get_config() config = litellm.OllamaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
if response.status != 200: if response.status != 200:
text = await response.text() text = await response.text()
raise OllamaError(status_code=response.status, message=text) raise OllamaError(status_code=response.status, message=text)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = output_data model_response["data"] = output_data
model_response["model"] = model model_response["model"] = model
input_tokens = len(encoding.encode(prompt)) input_tokens = len(encoding.encode(prompt))
model_response["usage"] = { model_response["usage"] = {
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
return model_response return model_response
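
A minimal sketch of the JSON-mode behaviour implemented above, assuming a local Ollama server on the default port with the llama2 model pulled; note that format="json" responses are surfaced as a synthetic tool call:

import litellm

response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Describe the sky as a JSON object."}],
    format="json",  # forwarded to Ollama's /api/generate
)

choice = response["choices"][0]
print(choice["finish_reason"])  # "stop" (hard-coded above)
# the raw JSON string returned by Ollama is placed in a synthetic function call
print(choice["message"].tool_calls[0]["function"]["arguments"])
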
View file
@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class OobaboogaError(Exception): class OobaboogaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,6 +16,7 @@ class OobaboogaError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"accept": "application/json", "accept": "application/json",
@ -24,6 +26,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}" headers["Authorization"] = f"Token {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -44,21 +47,24 @@ def completion(
completion_url = model completion_url = model
elif api_base: elif api_base:
completion_url = api_base completion_url = api_base
else: else:
raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')") raise OobaboogaError(
status_code=404,
message="API Base not set. Set one via completion(..,api_base='your-api-url')",
)
model = model model = model
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
completion_url = completion_url + "/api/v1/generate" completion_url = completion_url + "/api/v1/generate"
data = { data = {
"prompt": prompt, "prompt": prompt,
@ -66,30 +72,35 @@ def completion(
} }
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except:
raise OobaboogaError(message=response.text, status_code=response.status_code) raise OobaboogaError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response: if "error" in completion_response:
raise OobaboogaError( raise OobaboogaError(
message=completion_response["error"], message=completion_response["error"],
@@ -97,14 +108,17 @@ def completion(
) )
else: else:
try: try:
model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text'] model_response["choices"][0]["message"][
"content"
] = completion_response["results"][0]["text"]
except: except:
raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code) raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"]) encoding.encode(model_response["choices"][0]["message"]["content"])
) )
@@ -114,11 +128,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
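For orientation, a minimal sketch of the call shape this handler builds, assuming a local text-generation-webui instance (the api_base value and prompt text are examples, not part of the diff):

import json, requests

api_base = "http://localhost:5000"  # assumed example endpoint
headers = {"accept": "application/json"}  # validate_environment adds "Authorization: Token <key>" when an api_key is set
data = {"prompt": "### User:\nHello\n\n### Response:"}

resp = requests.post(f"{api_base}/api/v1/generate", headers=headers, data=json.dumps(data))
print(resp.json()["results"][0]["text"])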


@@ -1,30 +1,41 @@
from typing import List, Dict from typing import List, Dict
import types import types
class OpenrouterConfig():
class OpenrouterConfig:
""" """
Reference: https://openrouter.ai/docs#format Reference: https://openrouter.ai/docs#format
""" """
# OpenRouter-only parameters # OpenRouter-only parameters
extra_body: Dict[str, List[str]] = { extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
'transforms': [] # default transforms to []
}
def __init__(
def __init__(self, self,
transforms: List[str] = [], transforms: List[str] = [],
models: List[str] = [], models: List[str] = [],
route: str = '', route: str = "",
) -> None: ) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
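A small usage sketch of the config pattern above: any non-None constructor argument is stored at class level and read back via get_config() (the transform and model values are assumed examples):

import litellm

litellm.OpenrouterConfig(transforms=["middle-out"], models=["openrouter/auto"])  # assumed example values

config = litellm.OpenrouterConfig.get_config()
# config contains 'extra_body' plus the stored constructor values ('transforms', 'models', 'route')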


@@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm import litellm
import sys, httpx import sys, httpx
class PalmError(Exception): class PalmError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat") self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class PalmConfig():
class PalmConfig:
""" """
Reference: https://developers.generativeai.google/api/python/google/generativeai/chat Reference: https://developers.generativeai.google/api/python/google/generativeai/chat
@@ -37,35 +42,47 @@ class PalmConfig():
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output - `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
""" """
context: Optional[str]=None
examples: Optional[list]=None
temperature: Optional[float]=None
candidate_count: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
max_output_tokens: Optional[int]=None
def __init__(self, context: Optional[str] = None
context: Optional[str]=None, examples: Optional[list] = None
examples: Optional[list]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, candidate_count: Optional[int] = None
candidate_count: Optional[int]=None, top_k: Optional[int] = None
top_k: Optional[int]=None, top_p: Optional[float] = None
top_p: Optional[float]=None, max_output_tokens: Optional[int] = None
max_output_tokens: Optional[int]=None) -> None:
def __init__(
self,
context: Optional[str] = None,
examples: Optional[list] = None,
temperature: Optional[float] = None,
candidate_count: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
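The parameters documented above become class-level defaults that are merged into each request, with per-call values taking precedence, exactly as the completion() loop below does; a short sketch with assumed example values:

import litellm

litellm.PalmConfig(temperature=0.7, max_output_tokens=256)  # assumed example defaults

inference_params = {"top_k": 3}  # per-request params win
for k, v in litellm.PalmConfig.get_config().items():
    if k not in inference_params:
        inference_params[k] = v
# inference_params -> {'top_k': 3, 'temperature': 0.7, 'max_output_tokens': 256}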
def completion( def completion(
@@ -83,41 +100,43 @@ def completion(
try: try:
import google.generativeai as palm import google.generativeai as palm
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
palm.configure(api_key=api_key) palm.configure(api_key=api_key)
model = model model = model
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py inference_params.pop(
config = litellm.PalmConfig.get_config() "stream", None
for k, v in config.items(): ) # palm does not support streaming, so we handle this by fake streaming in main.py
if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in config = litellm.PalmConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
prompt = "" prompt = ""
for message in messages: for message in messages:
if "role" in message: if "role" in message:
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": {"inference_params": inference_params}}, additional_args={"complete_input_dict": {"inference_params": inference_params}},
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
response = palm.generate_text(prompt=prompt, **inference_params) response = palm.generate_text(prompt=prompt, **inference_params)
except Exception as e: except Exception as e:
raise PalmError( raise PalmError(
@@ -127,11 +146,11 @@ def completion(
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=response, original_response=response,
additional_args={"complete_input_dict": {}}, additional_args={"complete_input_dict": {}},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = response completion_response = response
@@ -142,22 +161,25 @@ def completion(
message_obj = Message(content=item["output"]) message_obj = Message(content=item["output"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code) raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
try: )
try:
completion_response = model_response["choices"][0]["message"].get("content") completion_response = model_response["choices"][0]["message"].get("content")
except: except:
raise PalmError(status_code=400, message=f"No response received. Original response - {response}") raise PalmError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -165,13 +187,14 @@ def completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "palm/" + model model_response["model"] = "palm/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@@ -8,6 +8,7 @@ import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class PetalsError(Exception): class PetalsError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@@ -16,7 +17,8 @@ class PetalsError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class PetalsConfig():
class PetalsConfig:
""" """
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below: The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
@@ -30,45 +32,64 @@ class PetalsConfig():
- `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below: - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
- `temperature` (float, optional): This value sets the temperature for sampling. - `temperature` (float, optional): This value sets the temperature for sampling.
- `top_k` (integer, optional): This value sets the limit for top-k sampling. - `top_k` (integer, optional): This value sets the limit for top-k sampling.
- `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling. - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper. - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
""" """
max_length: Optional[int]=None
max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
repetition_penalty: Optional[float]=None
def __init__(self, max_length: Optional[int] = None
max_length: Optional[int]=None, max_new_tokens: Optional[
max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set int
do_sample: Optional[bool]=None, ] = litellm.max_tokens # petals requires max tokens to be set
temperature: Optional[float]=None, do_sample: Optional[bool] = None
top_k: Optional[int]=None, temperature: Optional[float] = None
top_p: Optional[float]=None, top_k: Optional[int] = None
repetition_penalty: Optional[float]=None) -> None: top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
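Against a hosted chat.petals.dev-style endpoint (the POST /api/v1/generate route referenced in the docstring above), the request body reduces to the model, the rendered prompt, and whichever of the parameters above are set; a hedged sketch, with the endpoint and model name as assumed examples:

import requests

api_base = "https://chat.petals.dev/api/v1/generate"  # assumed example endpoint
data = {
    "model": "petals-team/StableBeluga2",  # assumed example model
    "inputs": "### User:\nHello\n\n### Response:",
    "max_new_tokens": 256,
    "do_sample": 1,
    "temperature": 0.8,
}
output_text = requests.post(api_base, data=data).json()["outputs"]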
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
api_base: Optional[str], api_base: Optional[str],
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
@@ -80,96 +101,97 @@ def completion(
): ):
## Load Config ## Load Config
config = litellm.PetalsConfig.get_config() config = litellm.PetalsConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
if model in litellm.custom_prompt_dict: if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model] model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
if api_base: if api_base:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": optional_params, "api_base": api_base}, additional_args={
) "complete_input_dict": optional_params,
data = { "api_base": api_base,
"model": model, },
"inputs": prompt, )
**optional_params data = {"model": model, "inputs": prompt, **optional_params}
}
## COMPLETION CALL ## COMPLETION CALL
response = requests.post(api_base, data=data) response = requests.post(api_base, data=data)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": optional_params}, additional_args={"complete_input_dict": optional_params},
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
output_text = response.json()["outputs"] output_text = response.json()["outputs"]
except Exception as e: except Exception as e:
PetalsError(status_code=response.status_code, message=str(e)) PetalsError(status_code=response.status_code, message=str(e))
else: else:
try: try:
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM # type: ignore from petals import AutoDistributedModelForCausalLM # type: ignore
except: except:
raise Exception( raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals" "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
) )
model = model model = model
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False) tokenizer = AutoTokenizer.from_pretrained(
model, use_fast=False, add_bos_token=False
)
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model) model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": optional_params}, additional_args={"complete_input_dict": optional_params},
) )
## COMPLETION CALL ## COMPLETION CALL
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"] inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
# optional params: max_new_tokens=1,temperature=0.9, top_p=0.6 # optional params: max_new_tokens=1,temperature=0.9, top_p=0.6
outputs = model_obj.generate(inputs, **optional_params) outputs = model_obj.generate(inputs, **optional_params)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=outputs, original_response=outputs,
additional_args={"complete_input_dict": optional_params}, additional_args={"complete_input_dict": optional_params},
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
output_text = tokenizer.decode(outputs[0]) output_text = tokenizer.decode(outputs[0])
if len(output_text) > 0: if len(output_text) > 0:
model_response["choices"][0]["message"]["content"] = output_text model_response["choices"][0]["message"]["content"] = output_text
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
@@ -177,13 +199,14 @@ def completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@@ -4,11 +4,13 @@ import json
from jinja2 import Template, exceptions, Environment, meta from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any from typing import Optional, Any
def default_pt(messages): def default_pt(messages):
return " ".join(message["content"] for message in messages) return " ".join(message["content"] for message in messages)
# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages): # alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
"system": { "system": {
@@ -19,59 +21,56 @@ def alpaca_pt(messages):
"pre_message": "### Instruction:\n", "pre_message": "### Instruction:\n",
"post_message": "\n\n", "post_message": "\n\n",
}, },
"assistant": { "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
"pre_message": "### Response:\n",
"post_message": "\n\n"
}
}, },
bos_token="<s>", bos_token="<s>",
eos_token="</s>", eos_token="</s>",
messages=messages messages=messages,
) )
return prompt return prompt
# Llama2 prompt template # Llama2 prompt template
def llama_2_chat_pt(messages): def llama_2_chat_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
"system": { "system": {
"pre_message": "[INST] <<SYS>>\n", "pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n" "post_message": "\n<</SYS>>\n [/INST]\n",
}, },
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348 "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"pre_message": "[INST] ", "pre_message": "[INST] ",
"post_message": " [/INST]\n" "post_message": " [/INST]\n",
}, },
"assistant": { "assistant": {
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
} },
}, },
messages=messages, messages=messages,
bos_token="<s>", bos_token="<s>",
eos_token="</s>" eos_token="</s>",
) )
return prompt return prompt
def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
def ollama_pt(
if "instruct" in model: model, messages
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
"system": { "system": {"pre_message": "### System:\n", "post_message": "\n"},
"pre_message": "### System:\n",
"post_message": "\n"
},
"user": { "user": {
"pre_message": "### User:\n", "pre_message": "### User:\n",
"post_message": "\n", "post_message": "\n",
}, },
"assistant": { "assistant": {
"pre_message": "### Response:\n", "pre_message": "### Response:\n",
"post_message": "\n", "post_message": "\n",
} },
}, },
final_prompt_value="### Response:", final_prompt_value="### Response:",
messages=messages messages=messages,
) )
elif "llava" in model: elif "llava" in model:
prompt = "" prompt = ""
@@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
elif element["type"] == "image_url": elif element["type"] == "image_url":
image_url = element["image_url"]["url"] image_url = element["image_url"]["url"]
images.append(image_url) images.append(image_url)
return { return {"prompt": prompt, "images": images}
"prompt": prompt, else:
"images": images prompt = "".join(
} m["content"]
else: if isinstance(m["content"], str) is str
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages) else "".join(m["content"])
for m in messages
)
return prompt return prompt
def mistral_instruct_pt(messages):
def mistral_instruct_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
initial_prompt_value="<s>", initial_prompt_value="<s>",
role_dict={ role_dict={
"system": { "system": {"pre_message": "[INST]", "post_message": "[/INST]"},
"pre_message": "[INST]", "user": {"pre_message": "[INST]", "post_message": "[/INST]"},
"post_message": "[/INST]" "assistant": {"pre_message": "[INST]", "post_message": "[/INST]"},
},
"user": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"assistant": {
"pre_message": "[INST]",
"post_message": "[/INST]"
}
}, },
final_prompt_value="</s>", final_prompt_value="</s>",
messages=messages messages=messages,
) )
return prompt return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages): def falcon_instruct_pt(messages):
prompt = "" prompt = ""
@@ -125,11 +119,16 @@ def falcon_instruct_pt(messages):
if message["role"] == "system": if message["role"] == "system":
prompt += message["content"] prompt += message["content"]
else: else:
prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n") prompt += (
message["role"]
+ ":"
+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
)
prompt += "\n\n" prompt += "\n\n"
return prompt return prompt
def falcon_chat_pt(messages): def falcon_chat_pt(messages):
prompt = "" prompt = ""
for message in messages: for message in messages:
@@ -142,6 +141,7 @@ def falcon_chat_pt(messages):
return prompt return prompt
# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 # MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages): def mpt_chat_pt(messages):
prompt = "" prompt = ""
@@ -154,18 +154,20 @@ def mpt_chat_pt(messages):
prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n" prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
return prompt return prompt
# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format # WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages): def wizardcoder_pt(messages):
prompt = "" prompt = ""
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
prompt += message["content"] + "\n\n" prompt += message["content"] + "\n\n"
elif message["role"] == "user": # map to 'Instruction' elif message["role"] == "user": # map to 'Instruction'
prompt += "### Instruction:\n" + message["content"] + "\n\n" prompt += "### Instruction:\n" + message["content"] + "\n\n"
elif message["role"] == "assistant": # map to 'Response' elif message["role"] == "assistant": # map to 'Response'
prompt += "### Response:\n" + message["content"] + "\n\n" prompt += "### Response:\n" + message["content"] + "\n\n"
return prompt return prompt
# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model # Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages): def phind_codellama_pt(messages):
prompt = "" prompt = ""
@@ -178,13 +180,17 @@ def phind_codellama_pt(messages):
prompt += "### Assistant\n" + message["content"] + "\n\n" prompt += "### Assistant\n" + message["content"] + "\n\n"
return prompt return prompt
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
## get the tokenizer config from huggingface ## get the tokenizer config from huggingface
bos_token = "" bos_token = ""
eos_token = "" eos_token = ""
if chat_template is None: if chat_template is None:
def _get_tokenizer_config(hf_model_name): def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" url = (
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
)
# Make a GET request to fetch the JSON data # Make a GET request to fetch the JSON data
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
@@ -193,10 +199,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
return {"status": "success", "tokenizer": tokenizer_config} return {"status": "success", "tokenizer": tokenizer_config}
else: else:
return {"status": "failure"} return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model) tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]: if (
tokenizer_config["status"] == "failure"
or "chat_template" not in tokenizer_config["tokenizer"]
):
raise Exception("No chat template found") raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json ## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"] tokenizer_config = tokenizer_config["tokenizer"]
bos_token = tokenizer_config["bos_token"] bos_token = tokenizer_config["bos_token"]
eos_token = tokenizer_config["eos_token"] eos_token = tokenizer_config["eos_token"]
@@ -204,10 +214,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def raise_exception(message): def raise_exception(message):
raise Exception(f"Error message - {message}") raise Exception(f"Error message - {message}")
# Create a template object from the template text # Create a template object from the template text
env = Environment() env = Environment()
env.globals['raise_exception'] = raise_exception env.globals["raise_exception"] = raise_exception
try: try:
template = env.from_string(chat_template) template = env.from_string(chat_template)
except Exception as e: except Exception as e:
@@ -216,137 +226,167 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def _is_system_in_template(): def _is_system_in_template():
try: try:
# Try rendering the template with a system message # Try rendering the template with a system message
response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>") response = template.render(
messages=[{"role": "system", "content": "test"}],
eos_token="<eos>",
bos_token="<bos>",
)
return True return True
# This will be raised if Jinja attempts to render the system message and it can't # This will be raised if Jinja attempts to render the system message and it can't
except: except:
return False return False
try: try:
# Render the template with the provided values # Render the template with the provided values
if _is_system_in_template(): if _is_system_in_template():
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages) rendered_text = template.render(
else: bos_token=bos_token, eos_token=eos_token, messages=messages
)
else:
# treat a system message as a user message, if system not in template # treat a system message as a user message, if system not in template
try: try:
reformatted_messages = [] reformatted_messages = []
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
reformatted_messages.append({"role": "user", "content": message["content"]}) reformatted_messages.append(
{"role": "user", "content": message["content"]}
)
else: else:
reformatted_messages.append(message) reformatted_messages.append(message)
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages) rendered_text = template.render(
bos_token=bos_token,
eos_token=eos_token,
messages=reformatted_messages,
)
except Exception as e: except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e): if "Conversation roles must alternate user/assistant" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = [] new_messages = []
for i in range(len(reformatted_messages)-1): for i in range(len(reformatted_messages) - 1):
new_messages.append(reformatted_messages[i]) new_messages.append(reformatted_messages[i])
if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]: if (
reformatted_messages[i]["role"]
== reformatted_messages[i + 1]["role"]
):
if reformatted_messages[i]["role"] == "user": if reformatted_messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""}) new_messages.append(
{"role": "assistant", "content": ""}
)
else: else:
new_messages.append({"role": "user", "content": ""}) new_messages.append({"role": "user", "content": ""})
new_messages.append(reformatted_messages[-1]) new_messages.append(reformatted_messages[-1])
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages) rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=new_messages
)
return rendered_text return rendered_text
except Exception as e: except Exception as e:
raise Exception(f"Error rendering template - {str(e)}") raise Exception(f"Error rendering template - {str(e)}")
# Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts # Anthropic template
def claude_2_1_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
""" """
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human: Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
- you can't just pass a system message - you can't just pass a system message
- you can't pass a system message and follow that with an assistant message - you can't pass a system message and follow that with an assistant message
if system message is passed in, you can only do system, human, assistant or system, human if system message is passed in, you can only do system, human, assistant or system, human
if a system message is passed in and followed by an assistant message, insert a blank human message between them. if a system message is passed in and followed by an assistant message, insert a blank human message between them.
""" """
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
prompt = "" prompt = ""
for idx, message in enumerate(messages): for idx, message in enumerate(messages):
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
elif message["role"] == "system": elif message["role"] == "system":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
elif message["role"] == "assistant": elif message["role"] == "assistant":
if idx > 0 and messages[idx - 1]["role"] == "system": if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += ( prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
)
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
return prompt return prompt
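A concrete example of the rule in the docstring: a system message followed directly by an assistant message gets a blank Human turn inserted between them, and the prompt always closes with an Assistant turn.

messages = [
    {"role": "system", "content": "You are a poet."},
    {"role": "assistant", "content": "Roses are red."},
]
print(claude_2_1_pt(messages))
# "You are a poet.\n\nHuman: \n\nAssistant: Roses are red.\n\nAssistant: "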
### TOGETHER AI
### TOGETHER AI
def get_model_info(token, model): def get_model_info(token, model):
try: try:
headers = { headers = {"Authorization": f"Bearer {token}"}
'Authorization': f'Bearer {token}' response = requests.get("https://api.together.xyz/models/info", headers=headers)
}
response = requests.get('https://api.together.xyz/models/info', headers=headers)
if response.status_code == 200: if response.status_code == 200:
model_info = response.json() model_info = response.json()
for m in model_info: for m in model_info:
if m["name"].lower().strip() == model.strip(): if m["name"].lower().strip() == model.strip():
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None) return m["config"].get("prompt_format", None), m["config"].get(
"chat_template", None
)
return None, None return None, None
else: else:
return None, None return None, None
except Exception as e: # safely fail a prompt template request except Exception as e: # safely fail a prompt template request
return None, None return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template): def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None: if prompt_format is None:
return default_pt(messages) return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split('{prompt}') human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None: if chat_template is not None:
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template) prompt = hf_chat_template(
elif prompt_format is not None: model=None, messages=messages, chat_template=chat_template
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt) )
else: elif prompt_format is not None:
prompt = custom_prompt(
role_dict={},
messages=messages,
initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt,
)
else:
prompt = default_pt(messages) prompt = default_pt(messages)
return prompt return prompt
### ###
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
def anthropic_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
prompt = "" prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: ` for idx, message in enumerate(
messages
): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
elif message["role"] == "system": elif message["role"] == "system":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
)
else: else:
prompt += ( prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" if (
) idx == 0 and message["role"] == "assistant"
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: ` ): # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}" prompt += f"{AnthropicConstants.AI_PROMPT.value}"
return prompt return prompt
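For the pre-2.1 Claude format above, system content is wrapped in <admin> tags inside a Human turn, for example:

messages = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi"},
]
print(anthropic_pt(messages))
# "\n\nHuman: <admin>Be terse.</admin>\n\nHuman: Hi\n\nAssistant: "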
def gemini_text_image_pt(messages: list):
def gemini_text_image_pt(messages: list):
""" """
{ {
"contents":[ "contents":[
@@ -367,13 +407,15 @@ def gemini_text_image_pt(messages: list):
try: try:
import google.generativeai as genai import google.generativeai as genai
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
prompt = "" prompt = ""
images = [] images = []
for message in messages: for message in messages:
if isinstance(message["content"], str): if isinstance(message["content"], str):
prompt += message["content"] prompt += message["content"]
elif isinstance(message["content"], list): elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]: for element in message["content"]:
@@ -383,45 +425,63 @@ def gemini_text_image_pt(messages: list):
elif element["type"] == "image_url": elif element["type"] == "image_url":
image_url = element["image_url"]["url"] image_url = element["image_url"]["url"]
images.append(image_url) images.append(image_url)
content = [prompt] + images content = [prompt] + images
return content return content
# Function call template
# Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:" function_prompt = (
for function in functions: "Produce JSON OUTPUT ONLY! The following functions are available to you:"
)
for function in functions:
function_prompt += f"""\n{function}\n""" function_prompt += f"""\n{function}\n"""
function_added_to_prompt = False function_added_to_prompt = False
for message in messages: for message in messages:
if "system" in message["role"]: if "system" in message["role"]:
message['content'] += f"""{function_prompt}""" message["content"] += f"""{function_prompt}"""
function_added_to_prompt = True function_added_to_prompt = True
if function_added_to_prompt == False: if function_added_to_prompt == False:
messages.append({'role': 'system', 'content': f"""{function_prompt}"""}) messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages return messages
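A short example of the behavior above: when no system message exists, one is appended carrying the function-calling instruction and the serialized function definitions (the function schema is an assumed example):

messages = [{"role": "user", "content": "What's the weather in SF?"}]
functions = [{"name": "get_current_weather", "parameters": {"location": "string"}}]  # assumed example schema
messages = function_call_prompt(messages, functions)
# messages[-1] is now a system message starting with
# "Produce JSON OUTPUT ONLY! The following functions are available to you:" followed by the function dict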
# Custom prompt template # Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""): def custom_prompt(
role_dict: dict,
messages: list,
initial_prompt_value: str = "",
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
prompt = bos_token + initial_prompt_value prompt = bos_token + initial_prompt_value
bos_open = True bos_open = True
## a bos token is at the start of a system / human message ## a bos token is at the start of a system / human message
## an eos token is at the end of the assistant response to the message ## an eos token is at the end of the assistant response to the message
for message in messages: for message in messages:
role = message["role"] role = message["role"]
if role in ["system", "human"] and not bos_open: if role in ["system", "human"] and not bos_open:
prompt += bos_token prompt += bos_token
bos_open = True bos_open = True
pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else "" pre_message_str = (
post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else "" role_dict[role]["pre_message"]
if role in role_dict and "pre_message" in role_dict[role]
else ""
)
post_message_str = (
role_dict[role]["post_message"]
if role in role_dict and "post_message" in role_dict[role]
else ""
)
prompt += pre_message_str + message["content"] + post_message_str prompt += pre_message_str + message["content"] + post_message_str
if role == "assistant": if role == "assistant":
prompt += eos_token prompt += eos_token
bos_open = False bos_open = False
@@ -429,25 +489,35 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += final_prompt_value prompt += final_prompt_value
return prompt return prompt
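To make the pre/post-message bookkeeping concrete, a small example using assumed role separators in the style of the registered templates above:

prompt = custom_prompt(
    role_dict={
        "system": {"pre_message": "### System:\n", "post_message": "\n"},
        "user": {"pre_message": "### User:\n", "post_message": "\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n"},
    },
    messages=[
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi"},
    ],
    final_prompt_value="### Response:",
)
# -> "### System:\nBe brief.\n### User:\nHi\n### Response:"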
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
def prompt_factory(
model: str,
messages: list,
custom_llm_provider: Optional[str] = None,
api_key: Optional[str] = None,
):
original_model_name = model original_model_name = model
model = model.lower() model = model.lower()
if custom_llm_provider == "ollama": if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages) return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
if "claude-2.1" in model: if "claude-2.1" in model:
return claude_2_1_pt(messages=messages) return claude_2_1_pt(messages=messages)
else: else:
return anthropic_pt(messages=messages) return anthropic_pt(messages=messages)
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model) prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template) return format_prompt_togetherai(
elif custom_llm_provider == "gemini": messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
return gemini_text_image_pt(messages=messages) return gemini_text_image_pt(messages=messages)
try: try:
if "meta-llama/llama-2" in model and "chat" in model: if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages) return llama_2_chat_pt(messages=messages)
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template. elif (
"tiiuae/falcon" in model
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if model == "tiiuae/falcon-180B-chat": if model == "tiiuae/falcon-180B-chat":
return falcon_chat_pt(messages=messages) return falcon_chat_pt(messages=messages)
elif "instruct" in model: elif "instruct" in model:
@@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return mpt_chat_pt(messages=messages) return mpt_chat_pt(messages=messages)
elif "codellama/codellama" in model or "togethercomputer/codellama" in model: elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
if "instruct" in model: if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions return llama_2_chat_pt(
messages=messages
) # https://huggingface.co/blog/codellama#conversational-instructions
elif "wizardlm/wizardcoder" in model: elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages) return wizardcoder_pt(messages=messages)
elif "phind/phind-codellama" in model: elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages) return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model): elif "togethercomputer/llama-2" in model and (
"instruct" in model or "chat" in model
):
return llama_2_chat_pt(messages=messages) return llama_2_chat_pt(messages=messages)
elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]: elif model in [
return alpaca_pt(messages=messages) "gryphe/mythomax-l2-13b",
else: "gryphe/mythomix-l2-13b",
"gryphe/mythologic-l2-13b",
]:
return alpaca_pt(messages=messages)
else:
return hf_chat_template(original_model_name, messages) return hf_chat_template(original_model_name, messages)
except Exception as e: except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
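A routing sketch for the dispatcher above (the model names are examples of the substring checks, not an exhaustive list):

# Falls through the provider branches, then matches "meta-llama/llama-2" plus "chat".
prompt = prompt_factory(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Hi"}],
)  # -> llama_2_chat_pt output

# Anthropic models route by provider, then by version.
prompt = prompt_factory(
    model="claude-2.1",
    messages=[{"role": "user", "content": "Hi"}],
    custom_llm_provider="anthropic",
)  # -> claude_2_1_pt output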


@@ -4,81 +4,100 @@ import requests
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
import litellm import litellm
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class ReplicateError(Exception): class ReplicateError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments") self.request = httpx.Request(
method="POST", url="https://api.replicate.com/v1/deployments"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ReplicateConfig():
class ReplicateConfig:
""" """
Reference: https://replicate.com/meta/llama-2-70b-chat/api Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model. - `prompt` (string): The prompt to send to the model.
- `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`. - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
- `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`. - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
- `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`. - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
- `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`. - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
- `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`. - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
- `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`. - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
- `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'. - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
- `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed. - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
- `debug` (boolean): If set to `True`, it provides debugging output in logs. - `debug` (boolean): If set to `True`, it provides debugging output in logs.
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models. Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
""" """
system_prompt: Optional[str]=None
max_new_tokens: Optional[int]=None
min_new_tokens: Optional[int]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
stop_sequences: Optional[str]=None
seed: Optional[int]=None
debug: Optional[bool]=None
def __init__(self, system_prompt: Optional[str] = None
system_prompt: Optional[str]=None, max_new_tokens: Optional[int] = None
max_new_tokens: Optional[int]=None, min_new_tokens: Optional[int] = None
min_new_tokens: Optional[int]=None, temperature: Optional[int] = None
temperature: Optional[int]=None, top_p: Optional[int] = None
top_p: Optional[int]=None, top_k: Optional[int] = None
top_k: Optional[int]=None, stop_sequences: Optional[str] = None
stop_sequences: Optional[str]=None, seed: Optional[int] = None
seed: Optional[int]=None, debug: Optional[bool] = None
debug: Optional[bool]=None) -> None:
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
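The same config pattern applies here; a hedged sketch of how class-level defaults can be read back and folded into a prediction input (values are assumed examples):

import litellm

litellm.ReplicateConfig(max_new_tokens=256, temperature=0.75, top_p=0.9)  # assumed example defaults

input_data = {"prompt": "Hello", **litellm.ReplicateConfig.get_config()}
# e.g. {'prompt': 'Hello', 'max_new_tokens': 256, 'temperature': 0.75, 'top_p': 0.9}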
# Function to start a prediction and get the prediction URL # Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose): def start_prediction(
version_id, input_data, api_token, api_base, logging_obj, print_verbose
):
base_url = api_base base_url = api_base
if "deployments" in version_id: if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment") print_verbose("\nLiteLLM: Request to custom replicate deployment")
@@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
initial_prediction_data = { initial_prediction_data = {
@@ -98,24 +117,33 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input_data["prompt"], input=input_data["prompt"],
api_key="", api_key="",
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url}, additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
) )
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers) response = requests.post(
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
)
if response.status_code == 201: if response.status_code == 201:
response_data = response.json() response_data = response.json()
return response_data.get("urls", {}).get("get") return response_data.get("urls", {}).get("get")
else: else:
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}") raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming) # Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose): def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = "" output_string = ""
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
status = "" status = ""
@@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
if "output" in response_data: if "output" in response_data:
output_string = "".join(response_data['output']) output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}") print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get('status', None) status = response_data.get("status", None)
logs = response_data.get("logs", "") logs = response_data.get("logs", "")
if status == "failed": if status == "failed":
replicate_error = response_data.get("error", "") replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}") raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else: else:
# this can fail temporarily, but it does not mean the replicate request failed; the replicate request only fails when status=="failed" # this can fail temporarily, but it does not mean the replicate request failed; the replicate request only fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.") print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs return output_string, logs
# Function to handle prediction response (streaming) # Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose): def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = "" previous_output = ""
@@ -146,30 +178,34 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
status = "" status = ""
while True and (status not in ["succeeded", "failed", "canceled"]): while True and (status not in ["succeeded", "failed", "canceled"]):
time.sleep(0.5) # prevent being rate limited by replicate time.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}") print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = requests.get(prediction_url, headers=headers) response = requests.get(prediction_url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
status = response_data['status'] status = response_data["status"]
if "output" in response_data: if "output" in response_data:
output_string = "".join(response_data['output']) output_string = "".join(response_data["output"])
new_output = output_string[len(previous_output):] new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}") print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status} yield {"output": new_output, "status": status}
previous_output = output_string previous_output = output_string
status = response_data['status'] status = response_data["status"]
if status == "failed": if status == "failed":
replicate_error = response_data.get("error", "") replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}") raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else: else:
# this can fail temporarily, but it does not mean the replicate request failed; the replicate request only fails when status=="failed" # this can fail temporarily, but it does not mean the replicate request failed; the replicate request only fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}") print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to extract version ID from model string # Function to extract version ID from model string
def model_to_version_id(model): def model_to_version_id(model):
@@ -178,11 +214,12 @@ def model_to_version_id(model):
return split_model[1] return split_model[1]
return model return model
# Main function for prediction completion # Main function for prediction completion
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
logging_obj, logging_obj,
@@ -196,35 +233,37 @@ def completion(
# Start a prediction and get the prediction URL # Start a prediction and get the prediction URL
version_id = model_to_version_id(model) version_id = model_to_version_id(model)
## Load Config ## Load Config
config = litellm.ReplicateConfig.get_config() config = litellm.ReplicateConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
system_prompt = None system_prompt = None
if optional_params is not None and "supports_system_prompt" in optional_params: if optional_params is not None and "supports_system_prompt" in optional_params:
supports_sys_prompt = optional_params.pop("supports_system_prompt") supports_sys_prompt = optional_params.pop("supports_system_prompt")
else: else:
supports_sys_prompt = False supports_sys_prompt = False
if supports_sys_prompt: if supports_sys_prompt:
for i in range(len(messages)): for i in range(len(messages)):
if messages[i]["role"] == "system": if messages[i]["role"] == "system":
first_sys_message = messages.pop(i) first_sys_message = messages.pop(i)
system_prompt = first_sys_message["content"] system_prompt = first_sys_message["content"]
break break
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}), role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""), bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""), eos_token=model_prompt_details.get("eos_token", ""),
messages=messages, messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
@@ -233,43 +272,58 @@ def completion(
input_data = { input_data = {
"prompt": prompt, "prompt": prompt,
"system_prompt": system_prompt, "system_prompt": system_prompt,
**optional_params **optional_params,
} }
# Otherwise, use the prompt as is # Otherwise, use the prompt as is
else: else:
input_data = { input_data = {"prompt": prompt, **optional_params}
"prompt": prompt,
**optional_params
}
## COMPLETION CALL ## COMPLETION CALL
## Replicate Completion calls have 2 steps ## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url ## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response ## Step2: Poll prediction url for response
## Step2: is handled with and without streaming ## Step2: is handled with and without streaming
model_response["created"] = int(time.time()) # for pricing this must remain right before calling api model_response["created"] = int(
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose) time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
)
print_verbose(prediction_url) print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming) # Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request") print_verbose("streaming request")
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose) return handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
else: else:
result, logs = handle_prediction_response(prediction_url, api_key, print_verbose) result, logs = handle_prediction_response(
model_response["ended"] = time.time() # for pricing this must remain right after calling api prediction_url, api_key, print_verbose
)
model_response[
"ended"
] = time.time() # for pricing this must remain right after calling api
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=result, original_response=result,
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, }, additional_args={
"complete_input_dict": input_data,
"logs": logs,
"api_base": prediction_url,
},
) )
print_verbose(f"raw model_response: {result}") print_verbose(f"raw model_response: {result}")
if len(result) == 0: # edge case, where result from replicate is empty if len(result) == 0: # edge case, where result from replicate is empty
result = " " result = " "
## Building RESPONSE OBJECT ## Building RESPONSE OBJECT
@@ -278,12 +332,14 @@ def completion(
# Calculate usage # Calculate usage
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", ""))) completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["model"] = "replicate/" + model model_response["model"] = "replicate/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response


@@ -11,42 +11,61 @@ from copy import deepcopy
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class SagemakerError(Exception): class SagemakerError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker") self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class SagemakerConfig():
class SagemakerConfig:
""" """
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
""" """
max_new_tokens: Optional[int]=None
top_p: Optional[float]=None
temperature: Optional[float]=None
return_full_text: Optional[bool]=None
def __init__(self, max_new_tokens: Optional[int] = None
max_new_tokens: Optional[int]=None, top_p: Optional[float] = None
top_p: Optional[float]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, return_full_text: Optional[bool] = None
return_full_text: Optional[bool]=None) -> None:
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
""" """
SAGEMAKER AUTH Keys/Vars SAGEMAKER AUTH Keys/Vars
os.environ['AWS_ACCESS_KEY_ID'] = "" os.environ['AWS_ACCESS_KEY_ID'] = ""
@@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name> # set os.environ['AWS_REGION_NAME'] = <your-region_name>
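A hedged usage sketch for the SageMaker completion path that follows. The sagemaker/<endpoint-name> model prefix is assumed to be how litellm routes requests to this module, and the endpoint name below is a placeholder; boto3 reads the credentials from the environment when no explicit keys are supplied, and the region falls back to us-west-2 exactly as the code below does.

import os

import litellm

# boto3 picks these up automatically when no explicit keys are passed in
os.environ["AWS_ACCESS_KEY_ID"] = "my-access-key-id"          # placeholder
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-secret-access-key"  # placeholder
os.environ["AWS_REGION_NAME"] = "us-west-2"

response = litellm.completion(
    model="sagemaker/my-llama2-chat-endpoint",  # hypothetical endpoint name
    messages=[{"role": "user", "content": "Hello from litellm"}],
    temperature=0.7,
    max_tokens=256,
)
print(response["choices"][0]["message"]["content"])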
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@@ -85,28 +105,30 @@ def completion(
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:
# aws_access_key_id is None, assume user is trying to auth using env variables # aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables # boto3 automatically reads env variables
# we need to read region name from env # we need to read region name from env
# I assume majority of users use .env for auth # I assume majority of users use .env for auth
region_name = ( region_name = (
get_secret("AWS_REGION_NAME") or get_secret("AWS_REGION_NAME")
"us-west-2" # default to us-west-2 if user not specified or "us-west-2" # default to us-west-2 if user not specified
) )
client = boto3.client( client = boto3.client(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
region_name=region_name, region_name=region_name,
) )
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params) inference_params = deepcopy(optional_params)
inference_params.pop("stream", None) inference_params.pop("stream", None)
## Load Config ## Load Config
config = litellm.SagemakerConfig.get_config() config = litellm.SagemakerConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
model = model model = model
@@ -114,25 +136,26 @@ def completion(
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None), role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages messages=messages,
) )
else: else:
if hf_model_name is None: if hf_model_name is None:
if "llama-2" in model.lower(): # llama-2 model if "llama-2" in model.lower(): # llama-2 model
if "chat" in model.lower(): # apply llama2 chat template if "chat" in model.lower(): # apply llama2 chat template
hf_model_name = "meta-llama/Llama-2-7b-chat-hf" hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b" hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = hf_model_name or model # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt) hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages) prompt = prompt_factory(model=hf_model_name, messages=messages)
data = json.dumps({ data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"inputs": prompt, "utf-8"
"parameters": inference_params )
}).encode('utf-8')
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
@@ -142,31 +165,35 @@ def completion(
Body={data}, Body={data},
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
) )
""" # type: ignore """ # type: ignore
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name}, additional_args={
) "complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
## COMPLETION CALL ## COMPLETION CALL
try: try:
response = client.invoke_endpoint( response = client.invoke_endpoint(
EndpointName=model, EndpointName=model,
ContentType="application/json", ContentType="application/json",
Body=data, Body=data,
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
) )
except Exception as e: except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}") raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"].read().decode("utf8") response = response["Body"].read().decode("utf8")
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=response, original_response=response,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT ## RESPONSE OBJECT
completion_response = json.loads(response) completion_response = json.loads(response)
@@ -177,19 +204,20 @@ def completion(
completion_output += completion_response_choices["generation"] completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices: elif "generated_text" in completion_response_choices:
completion_output += completion_response_choices["generated_text"] completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out # check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt: if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1) completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output model_response["choices"][0]["message"]["content"] = completion_output
except: except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500) raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -197,28 +225,32 @@ def completion(
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(model: str,
input: list, def embedding(
model_response: EmbeddingResponse, model: str,
print_verbose: Callable, input: list,
encoding, model_response: EmbeddingResponse,
logging_obj, print_verbose: Callable,
custom_prompt_dict={}, encoding,
optional_params=None, logging_obj,
litellm_params=None, custom_prompt_dict={},
logger_fn=None): optional_params=None,
litellm_params=None,
logger_fn=None,
):
""" """
Supports Huggingface Jumpstart embeddings like GPT-6B Supports Huggingface Jumpstart embeddings like GPT-6B
""" """
### BOTO3 INIT ### BOTO3 INIT
import boto3 import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@@ -234,34 +266,34 @@ def embedding(model: str,
region_name=aws_region_name, region_name=aws_region_name,
) )
else: else:
# aws_access_key_id is None, assume user is trying to auth using env variables # aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables # boto3 automatically reads env variables
# we need to read region name from env # we need to read region name from env
# I assume majority of users use .env for auth # I assume majority of users use .env for auth
region_name = ( region_name = (
get_secret("AWS_REGION_NAME") or get_secret("AWS_REGION_NAME")
"us-west-2" # default to us-west-2 if user not specified or "us-west-2" # default to us-west-2 if user not specified
) )
client = boto3.client( client = boto3.client(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
region_name=region_name, region_name=region_name,
) )
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params) inference_params = deepcopy(optional_params)
inference_params.pop("stream", None) inference_params.pop("stream", None)
## Load Config ## Load Config
config = litellm.SagemakerConfig.get_config() config = litellm.SagemakerConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
#### HF EMBEDDING LOGIC #### HF EMBEDDING LOGIC
data = json.dumps({ data = json.dumps({"text_inputs": input}).encode("utf-8")
"text_inputs": input
}).encode('utf-8')
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
@@ -270,67 +302,65 @@ def embedding(model: str,
ContentType="application/json", ContentType="application/json",
Body={data}, Body={data},
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
)""" # type: ignore )""" # type: ignore
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={"complete_input_dict": data, "request_str": request_str},
) )
## EMBEDDING CALL ## EMBEDDING CALL
try: try:
response = client.invoke_endpoint( response = client.invoke_endpoint(
EndpointName=model, EndpointName=model,
ContentType="application/json", ContentType="application/json",
Body=data, Body=data,
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
) )
except Exception as e: except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}") raise SagemakerError(status_code=500, message=f"{str(e)}")
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key="", api_key="",
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response, original_response=response,
) )
response = json.loads(response["Body"].read().decode("utf8")) response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
api_key="", api_key="",
original_response=response, original_response=response,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
if "embedding" not in response: if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response") raise SagemakerError(status_code=500, message="embedding not found in response")
embeddings = response['embedding'] embeddings = response["embedding"]
if not isinstance(embeddings, list): if not isinstance(embeddings, list):
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}") raise SagemakerError(
status_code=422, message=f"Response not in expected format - {embeddings}"
)
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = output_data model_response["data"] = output_data
model_response["model"] = model model_response["model"] = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
return model_response return model_response
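A short usage sketch for the embedding path above. The endpoint name is a placeholder and the sagemaker/ prefix is assumed to route litellm.embedding() to this module; the response shape (data, usage) matches what the function builds.

import litellm

response = litellm.embedding(
    model="sagemaker/my-jumpstart-gpt-j-6b-embedding",  # hypothetical endpoint
    input=["good morning from litellm"],
)
print(len(response["data"][0]["embedding"]), response["usage"])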


@@ -9,17 +9,21 @@ import httpx
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception): class TogetherAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference") self.request = httpx.Request(
method="POST", url="https://api.together.xyz/inference"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class TogetherAIConfig():
class TogetherAIConfig:
""" """
Reference: https://docs.together.ai/reference/inference Reference: https://docs.together.ai/reference/inference
@@ -37,35 +41,49 @@ class TogetherAIConfig():
- `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition. - `repetition_penalty` (float, optional): A number that controls the diversity of generated text by reducing the likelihood of repeated sequences. Higher values decrease repetition.
- `logprobs` (int32, optional): This parameter is not described in the referenced documentation. - `logprobs` (int32, optional): This parameter is not described in the referenced documentation.
""" """
max_tokens: Optional[int]=None
stop: Optional[str]=None max_tokens: Optional[int] = None
temperature:Optional[int]=None stop: Optional[str] = None
top_p: Optional[float]=None temperature: Optional[int] = None
top_k: Optional[int]=None top_p: Optional[float] = None
repetition_penalty: Optional[float]=None top_k: Optional[int] = None
logprobs: Optional[int]=None repetition_penalty: Optional[float] = None
logprobs: Optional[int] = None
def __init__(self,
max_tokens: Optional[int]=None, def __init__(
stop: Optional[str]=None, self,
temperature:Optional[int]=None, max_tokens: Optional[int] = None,
top_p: Optional[float]=None, stop: Optional[str] = None,
top_k: Optional[int]=None, temperature: Optional[int] = None,
repetition_penalty: Optional[float]=None, top_p: Optional[float] = None,
logprobs: Optional[int]=None) -> None: top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@@ -80,10 +98,11 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
@@ -97,9 +116,11 @@ def completion(
headers = validate_environment(api_key) headers = validate_environment(api_key)
## Load Config ## Load Config
config = litellm.TogetherAIConfig.get_config() config = litellm.TogetherAIConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}") print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
@@ -107,15 +128,20 @@ def completion(
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}), role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""), bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""), eos_token=model_prompt_details.get("eos_token", ""),
messages=messages, messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list prompt = prompt_factory(
model=model,
messages=messages,
api_key=api_key,
custom_llm_provider="together_ai",
) # api key required to query together ai model list
data = { data = {
"model": model, "model": model,
@@ -128,13 +154,14 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base}, additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
) )
## COMPLETION CALL ## COMPLETION CALL
if ( if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
response = requests.post( response = requests.post(
api_base, api_base,
headers=headers, headers=headers,
@@ -143,18 +170,14 @@ def completion(
) )
return response.iter_lines() return response.iter_lines()
else: else:
response = requests.post( response = requests.post(api_base, headers=headers, data=json.dumps(data))
api_base,
headers=headers,
data=json.dumps(data)
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
original_response=response.text, original_response=response.text,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
print_verbose(f"raw model_response: {response.text}") print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT ## RESPONSE OBJECT
if response.status_code != 200: if response.status_code != 200:
@@ -170,30 +193,38 @@ def completion(
) )
elif "error" in completion_response["output"]: elif "error" in completion_response["output"]:
raise TogetherAIError( raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code message=json.dumps(completion_response["output"]),
status_code=response.status_code,
) )
if len(completion_response["output"]["choices"][0]["text"]) >= 0: if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"] model_response["choices"][0]["message"]["content"] = completion_response[
"output"
]["choices"][0]["text"]
## CALCULATING USAGE ## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}") print_verbose(
f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
)
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
if "finish_reason" in completion_response["output"]["choices"][0]: if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"] model_response.choices[0].finish_reason = completion_response["output"][
"choices"
][0]["finish_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm import litellm
import httpx import httpx
class VertexAIError(Exception): class VertexAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/") self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class VertexAIConfig():
class VertexAIConfig:
""" """
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@@ -34,28 +38,42 @@ class VertexAIConfig():
Note: Please make sure to modify the default parameters as required for your use case. Note: Please make sure to modify the default parameters as required for your use case.
""" """
temperature: Optional[float]=None
max_output_tokens: Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
def __init__(self, temperature: Optional[float] = None
temperature: Optional[float]=None, max_output_tokens: Optional[int] = None
max_output_tokens: Optional[int]=None, top_p: Optional[float] = None
top_p: Optional[float]=None, top_k: Optional[int] = None
top_k: Optional[int]=None) -> None:
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
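Before the helper functions below, a hedged sketch of how this config and the safety_settings handling later in this file are meant to be exercised. The gemini-pro model name, the project id, and the harm category/threshold strings are illustrative assumptions rather than values taken from this diff; safety_settings only needs to be a list of dicts that gapic_content_types.SafetySetting can consume.

import litellm

litellm.vertex_project = "my-gcp-project"  # hypothetical project id
litellm.vertex_location = "us-central1"

# VertexAIConfig defaults merge into optional_params like the other providers
litellm.VertexAIConfig(temperature=0.25, max_output_tokens=256)

response = litellm.completion(
    model="gemini-pro",  # assumes a model listed in litellm.vertex_language_models
    messages=[{"role": "user", "content": "Say hi"}],
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
    ],
)
print(response["choices"][0]["message"]["content"])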
def _get_image_bytes_from_url(image_url: str) -> bytes: def _get_image_bytes_from_url(image_url: str) -> bytes:
try: try:
@@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
return image_bytes return image_bytes
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout) # Handle any request exceptions (e.g., connection error, timeout)
return b'' # Return an empty bytes object or handle the error as needed return b"" # Return an empty bytes object or handle the error as needed
def _load_image_from_url(image_url: str): def _load_image_from_url(image_url: str):
@@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str):
Returns: Returns:
Image: The loaded image. Image: The loaded image.
""" """
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
image_bytes = _get_image_bytes_from_url(image_url) image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes) return Image.from_bytes(image_bytes)
def _gemini_vision_convert_messages(
messages: list def _gemini_vision_convert_messages(messages: list):
):
""" """
Converts given messages for GPT-4 Vision to Gemini format. Converts given messages for GPT-4 Vision to Gemini format.
@@ -95,7 +118,7 @@ def _gemini_vision_convert_messages(
Returns: Returns:
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images). tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
Raises: Raises:
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed. VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
Exception: If any other exception occurs during the execution of the function. Exception: If any other exception occurs during the execution of the function.
@@ -115,11 +138,23 @@ def _gemini_vision_convert_messages(
try: try:
import vertexai import vertexai
except: except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") raise VertexAIError(
try: status_code=400,
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini # given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
@@ -159,6 +194,7 @@ def _gemini_vision_convert_messages(
except Exception as e: except Exception as e:
raise e raise e
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@@ -171,30 +207,38 @@ def completion(
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool=False acompletion: bool = False,
): ):
try: try:
import vertexai import vertexai
except: except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") raise VertexAIError(
try: status_code=400,
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try:
from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init(project=vertex_project, location=vertex_location)
vertexai.init(
project=vertex_project, location=vertex_location
)
## Load Config ## Load Config
config = litellm.VertexAIConfig.get_config() config = litellm.VertexAIConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: if k not in optional_params:
optional_params[k] = v optional_params[k] = v
## Process safety settings into format expected by vertex AI ## Process safety settings into format expected by vertex AI
safety_settings = None safety_settings = None
if "safety_settings" in optional_params: if "safety_settings" in optional_params:
safety_settings = optional_params.pop("safety_settings") safety_settings = optional_params.pop("safety_settings")
@@ -202,17 +246,25 @@ def completion(
raise ValueError("safety_settings must be a list") raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts") raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings] safety_settings = [
gapic_content_types.SafetySetting(x) for x in safety_settings
]
# vertexai does not use an API key, it looks for credentials.json in the environment # vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)]) prompt = " ".join(
[
message["content"]
for message in messages
if isinstance(message["content"], str)
]
)
mode = "" mode = ""
request_str = "" request_str = ""
response_obj = None response_obj = None
if model in litellm.vertex_language_models: if model in litellm.vertex_language_models:
llm_model = GenerativeModel(model) llm_model = GenerativeModel(model)
mode = "" mode = ""
request_str += f"llm_model = GenerativeModel({model})\n" request_str += f"llm_model = GenerativeModel({model})\n"
@@ -232,31 +284,76 @@ def completion(
llm_model = CodeGenerationModel.from_pretrained(model) llm_model = CodeGenerationModel.from_pretrained(model)
mode = "text" mode = "text"
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
else: # vertex_code_llm_models else: # vertex_code_llm_models
llm_model = CodeChatModel.from_pretrained(model) llm_model = CodeChatModel.from_pretrained(model)
mode = "chat" mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
if acompletion == True: # [TODO] expand support to vertex ai chat + text models if acompletion == True: # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True: if optional_params.get("stream", False) is True:
# async streaming # async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params) return async_streaming(
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params) llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
return async_completion(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
encoding=encoding,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
if mode == "": if mode == "":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n" request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "vision": elif mode == "vision":
@@ -268,21 +365,35 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content( model_response = llm_model.generate_content(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params), generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings, safety_settings=safety_settings,
stream=True stream=True,
) )
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"response = llm_model.generate_content({content})\n" request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call ## LLM Call
response = llm_model.generate_content( response = llm_model.generate_content(
contents=content, contents=content,
@@ -293,88 +404,150 @@ def completion(
response_obj = response._raw_response response_obj = response._raw_response
elif mode == "chat": elif mode == "chat":
chat = llm_model.start_chat() chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n" request_str += f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error, # NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request # we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params optional_params.pop(
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n" "stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = chat.send_message_streaming(prompt, **optional_params) model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = chat.send_message(prompt, **optional_params).text completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text": elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai optional_params.pop(
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n" "stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params) model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = llm_model.predict(prompt, **optional_params).text completion_response = llm_model.predict(prompt, **optional_params).text
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response input=prompt, api_key=None, original_response=completion_response
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
if len(str(completion_response)) > 0: if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = str(
"content" completion_response
] = str(completion_response) )
model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name model_response["choices"][0].finish_reason = response_obj.candidates[
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, 0
completion_tokens=response_obj.usage_metadata.candidates_token_count, ].finish_reason.name
total_tokens=response_obj.usage_metadata.total_token_count) usage = Usage(
else: prompt_tokens=response_obj.usage_metadata.prompt_token_count,
prompt_tokens = len( completion_tokens=response_obj.usage_metadata.candidates_token_count,
encoding.encode(prompt) total_tokens=response_obj.usage_metadata.total_token_count,
) )
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) )
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
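
The NOTE in the chat branch above describes a pattern worth spelling out: the Vertex AI SDK's streaming methods do not accept a stream kwarg, so the flag is popped before the call and re-added afterwards so main.py knows to wrap the result for the OpenAI format. A minimal sketch of that pattern, assuming a hypothetical helper name and with chat standing in for a Vertex AI chat session:

def call_with_stream_workaround(chat, prompt, optional_params):
    # pop the flag first: Vertex AI raises an error if `stream` is passed through
    if optional_params.pop("stream", None):
        response_iter = chat.send_message_streaming(prompt, **optional_params)
        optional_params["stream"] = True  # restored so the caller knows to wrap the iterator
        return response_iter
    return chat.send_message(prompt, **optional_params).text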
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
async def async_completion(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
encoding=None,
messages=None,
print_verbose=None,
**optional_params,
):
""" """
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
try: try:
from vertexai.preview.generative_models import GenerationConfig from vertexai.preview.generative_models import GenerationConfig
if mode == "": if mode == "":
# gemini-pro # gemini-pro
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params)) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "vision": elif mode == "vision":
@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
request_str += f"response = llm_model.generate_content({content})\n" request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call ## LLM Call
response = await llm_model._generate_content_async( response = await llm_model._generate_content_async(
contents=content, contents=content, generation_config=GenerationConfig(**optional_params)
generation_config=GenerationConfig(**optional_params)
) )
completion_response = response.text completion_response = response.text
response_obj = response._raw_response response_obj = response._raw_response
@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# chat-bison etc. # chat-bison etc.
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(prompt, **optional_params) response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
elif mode == "text": elif mode == "text":
# gecko etc. # gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await llm_model.predict_async(prompt, **optional_params) response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
@ -416,51 +609,77 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
if len(str(completion_response)) > 0: if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = str(
"content" completion_response
] = str(completion_response) )
model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name model_response["choices"][0].finish_reason = response_obj.candidates[
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, 0
completion_tokens=response_obj.usage_metadata.candidates_token_count, ].finish_reason.name
total_tokens=response_obj.usage_metadata.total_token_count) usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count,
)
else: else:
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) )
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
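
Both the sync and async paths above build usage the same way: Gemini-family responses expose usage_metadata, while older models fall back to counting tokens with the tokenizer. A compact sketch of that fallback, assuming a tiktoken-style encoding object and reusing the Usage class from litellm.utils:

from litellm.utils import Usage

def build_usage(response_obj, prompt, content, encoding, is_vertex_language_model):
    if is_vertex_language_model and response_obj is not None:
        md = response_obj.usage_metadata
        return Usage(
            prompt_tokens=md.prompt_token_count,
            completion_tokens=md.candidates_token_count,
            total_tokens=md.total_token_count,
        )
    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(encoding.encode(content or ""))
    return Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )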
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
async def async_streaming(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
messages=None,
print_verbose=None,
**optional_params,
):
""" """
Add support for async streaming calls for gemini-pro Add support for async streaming calls for gemini-pro
""" """
from vertexai.preview.generative_models import GenerationConfig from vertexai.preview.generative_models import GenerationConfig
if mode == "":
if mode == "":
# gemini-pro # gemini-pro
chat = llm_model.start_chat() chat = llm_model.start_chat()
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "vision": elif mode == "vision":
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
print_verbose("\nMaking VertexAI Gemini Pro Vision Call") print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r
content = [prompt] + images content = [prompt] + images
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model._generate_content_streaming_async( response = llm_model._generate_content_streaming_async(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params), generation_config=GenerationConfig(**optional_params),
stream=True stream=True,
) )
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "chat": elif mode == "chat":
chat = llm_model.start_chat() chat = llm_model.start_chat()
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params optional_params.pop(
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n" "stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = chat.send_message_streaming_async(prompt, **optional_params) response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "text": elif mode == "text":
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai optional_params.pop(
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n" "stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model.predict_streaming_async(prompt, **optional_params) response = llm_model.predict_streaming_async(prompt, **optional_params)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@ -6,7 +6,10 @@ import time, httpx
from typing import Callable, Any from typing import Callable, Any
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
llm = None llm = None
class VLLMError(Exception): class VLLMError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -17,17 +20,20 @@ class VLLMError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# check if vllm is installed # check if vllm is installed
def validate_environment(model: str): def validate_environment(model: str):
global llm global llm
try: try:
from vllm import LLM, SamplingParams # type: ignore from vllm import LLM, SamplingParams # type: ignore
if llm is None: if llm is None:
llm = LLM(model=model) llm = LLM(model=model)
return llm, SamplingParams return llm, SamplingParams
except Exception as e: except Exception as e:
raise VLLMError(status_code=0, message=str(e)) raise VLLMError(status_code=0, message=str(e))
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -50,15 +56,14 @@ def completion(
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -69,9 +74,10 @@ def completion(
if llm: if llm:
outputs = llm.generate(prompt, sampling_params) outputs = llm.generate(prompt, sampling_params)
else: else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return iter(outputs) return iter(outputs)
@ -88,24 +94,22 @@ def completion(
model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len(outputs[0].prompt_token_ids) prompt_tokens = len(outputs[0].prompt_token_ids)
completion_tokens = len(outputs[0].outputs[0].token_ids) completion_tokens = len(outputs[0].outputs[0].token_ids)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def batch_completions( def batch_completions(
model: str, model: str, messages: list, optional_params=None, custom_prompt_dict={}
messages: list,
optional_params=None,
custom_prompt_dict={}
): ):
""" """
Example usage: Example usage:
@ -137,31 +141,33 @@ def batch_completions(
except Exception as e: except Exception as e:
error_str = str(e) error_str = str(e)
if "data parallel group is already initialized" in error_str: if "data parallel group is already initialized" in error_str:
pass pass
else: else:
raise VLLMError(status_code=0, message=error_str) raise VLLMError(status_code=0, message=error_str)
sampling_params = SamplingParams(**optional_params) sampling_params = SamplingParams(**optional_params)
prompts = [] prompts = []
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
for message in messages: for message in messages:
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message messages=message,
) )
prompts.append(prompt) prompts.append(prompt)
else: else:
for message in messages: for message in messages:
prompt = prompt_factory(model=model, messages=message) prompt = prompt_factory(model=model, messages=message)
prompts.append(prompt) prompts.append(prompt)
if llm: if llm:
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
else: else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
final_outputs = [] final_outputs = []
for output in outputs: for output in outputs:
@ -170,20 +176,21 @@ def batch_completions(
model_response["choices"][0]["message"]["content"] = output.outputs[0].text model_response["choices"][0]["message"]["content"] = output.outputs[0].text
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len(output.prompt_token_ids) prompt_tokens = len(output.prompt_token_ids)
completion_tokens = len(output.outputs[0].token_ids) completion_tokens = len(output.outputs[0].token_ids)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
final_outputs.append(model_response) final_outputs.append(model_response)
return final_outputs return final_outputs
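
A hedged usage sketch for batch_completions: the model name and sampling values are illustrative, and optional_params must be a dict because it is unpacked into vLLM's SamplingParams.

responses = batch_completions(
    model="facebook/opt-125m",  # illustrative vLLM-compatible model
    messages=[
        [{"role": "user", "content": "Good morning?"}],
        [{"role": "user", "content": "What's 1 + 1?"}],
    ],
    optional_params={"temperature": 0.5, "max_tokens": 20},
)
for r in responses:
    print(r["choices"][0]["message"]["content"])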
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass

File diff suppressed because it is too large


@ -1 +1 @@
from . import * from . import *


@ -1,4 +1,4 @@
def my_custom_rule(input): # receives the model response def my_custom_rule(input): # receives the model response
# if len(input) < 5: # trigger fallback if the model response is too short # if len(input) < 5: # trigger fallback if the model response is too short
return False return False
return True return True


@ -3,13 +3,15 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
""" """
def json(self, **kwargs): def json(self, **kwargs):
try: try:
return self.model_dump() # noqa return self.model_dump() # noqa
except: except:
# if using pydantic v1 # if using pydantic v1
return self.dict() return self.dict()
@ -34,7 +36,7 @@ class ProxyChatCompletionRequest(LiteLLMBase):
tools: Optional[List[str]] = None tools: Optional[List[str]] = None
tool_choice: Optional[str] = None tool_choice: Optional[str] = None
functions: Optional[List[str]] = None # soon to be deprecated functions: Optional[List[str]] = None # soon to be deprecated
function_call: Optional[str] = None # soon to be deprecated function_call: Optional[str] = None # soon to be deprecated
# Optional LiteLLM params # Optional LiteLLM params
caching: Optional[bool] = None caching: Optional[bool] = None
@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
request_timeout: Optional[int] = None request_timeout: Optional[int] = None
class Config: class Config:
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase): class ModelInfoDelete(LiteLLMBase):
id: Optional[str] id: Optional[str]
@ -57,38 +60,37 @@ class ModelInfoDelete(LiteLLMBase):
class ModelInfo(LiteLLMBase): class ModelInfo(LiteLLMBase):
id: Optional[str] id: Optional[str]
mode: Optional[Literal['embedding', 'chat', 'completion']] mode: Optional[Literal["embedding", "chat", "completion"]]
input_cost_per_token: Optional[float] = 0.0 input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] = 0.0 output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] = 2048 # assume 2048 if not set max_tokens: Optional[int] = 2048 # assume 2048 if not set
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json # we look up the base model in model_prices_and_context_window.json
base_model: Optional[Literal base_model: Optional[
[ Literal[
'gpt-4-1106-preview', "gpt-4-1106-preview",
'gpt-4-32k', "gpt-4-32k",
'gpt-4', "gpt-4",
'gpt-3.5-turbo-16k', "gpt-3.5-turbo-16k",
'gpt-3.5-turbo', "gpt-3.5-turbo",
'text-embedding-ada-002', "text-embedding-ada-002",
] ]
] ]
class Config: class Config:
extra = Extra.allow # Allow extra fields extra = Extra.allow # Allow extra fields
protected_namespaces = () protected_namespaces = ()
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("id") is None: if values.get("id") is None:
values.update({"id": str(uuid.uuid4())}) values.update({"id": str(uuid.uuid4())})
if values.get("mode") is None: if values.get("mode") is None:
values.update({"mode": None}) values.update({"mode": None})
if values.get("input_cost_per_token") is None: if values.get("input_cost_per_token") is None:
values.update({"input_cost_per_token": None}) values.update({"input_cost_per_token": None})
if values.get("output_cost_per_token") is None: if values.get("output_cost_per_token") is None:
values.update({"output_cost_per_token": None}) values.update({"output_cost_per_token": None})
if values.get("max_tokens") is None: if values.get("max_tokens") is None:
values.update({"max_tokens": None}) values.update({"max_tokens": None})
@ -97,21 +99,21 @@ class ModelInfo(LiteLLMBase):
return values return values
class ModelParams(LiteLLMBase): class ModelParams(LiteLLMBase):
model_name: str model_name: str
litellm_params: dict litellm_params: dict
model_info: ModelInfo model_info: ModelInfo
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("model_info") is None: if values.get("model_info") is None:
values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
return values return values
class GenerateKeyRequest(LiteLLMBase): class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h" duration: Optional[str] = "1h"
models: Optional[list] = [] models: Optional[list] = []
@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {} metadata: Optional[dict] = {}
class UpdateKeyRequest(LiteLLMBase): class UpdateKeyRequest(LiteLLMBase):
key: str key: str
duration: Optional[str] = None duration: Optional[str] = None
@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {} metadata: Optional[dict] = {}
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
""" """
Return the row in the db Return the row in the db
""" """
api_key: Optional[str] = None api_key: Optional[str] = None
models: list = [] models: list = []
aliases: dict = {} aliases: dict = {}
@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
duration: str = "1h" duration: str = "1h"
metadata: dict = {} metadata: dict = {}
class GenerateKeyResponse(LiteLLMBase): class GenerateKeyResponse(LiteLLMBase):
key: str key: str
expires: Optional[datetime] expires: Optional[datetime]
user_id: str user_id: str
class _DeleteKeyObject(LiteLLMBase): class _DeleteKeyObject(LiteLLMBase):
key: str key: str
class DeleteKeyRequest(LiteLLMBase): class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject] keys: List[_DeleteKeyObject]
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class NewUserResponse(GenerateKeyResponse): class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """
Documents all the fields supported by `general_settings` in config.yaml Documents all the fields supported by `general_settings` in config.yaml
""" """
completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault") completion_model: Optional[str] = Field(
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy") None, description="proxy level default model for all chat completion calls"
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key") )
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.") use_azure_key_vault: Optional[bool] = Field(
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth") None, description="load keys from azure key vault"
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key") )
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)") master_key: Optional[str] = Field(
background_health_checks: Optional[bool] = Field(None, description="run health checks in background") None, description="require a key for all calls to proxy"
health_check_interval: int = Field(300, description="background health check interval in seconds") )
database_url: Optional[str] = Field(
None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
otel: Optional[bool] = Field(
None,
description="[BETA] OpenTelemetry support - this might change, use with caution.",
)
custom_auth: Optional[str] = Field(
None,
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
)
max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key"
)
infer_model_from_keys: Optional[bool] = Field(
None,
description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
)
background_health_checks: Optional[bool] = Field(
None, description="run health checks in background"
)
health_check_interval: int = Field(
300, description="background health check interval in seconds"
)
class ConfigYAML(LiteLLMBase): class ConfigYAML(LiteLLMBase):
""" """
Documents all the fields supported by the config.yaml Documents all the fields supported by the config.yaml
""" """
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache") model_list: Optional[List[ModelParams]] = Field(
None,
description="List of supported models on the server, with model-specific configs",
)
litellm_settings: Optional[dict] = Field(
None,
description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
)
general_settings: Optional[ConfigGeneralSettings] = None general_settings: Optional[ConfigGeneralSettings] = None
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
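
A hedged usage sketch for the request models above; the field values are illustrative, and json() resolves to model_dump() on pydantic v2 or dict() on v1 as defined on LiteLLMBase:

req = GenerateKeyRequest(
    duration="24h",
    models=["gpt-3.5-turbo"],
    metadata={"team": "engineering"},
)
print(req.json())  # returns a plain dict via model_dump()/dict(), not a JSON string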


@ -1,14 +1,16 @@
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request from fastapi import Request
from dotenv import load_dotenv from dotenv import load_dotenv
import os import os
load_dotenv() load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234" modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
if api_key == modified_master_key: if api_key == modified_master_key:
return UserAPIKeyAuth(api_key=api_key) return UserAPIKeyAuth(api_key=api_key)
raise Exception raise Exception
except: except:
raise Exception raise Exception
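
A hedged sketch of exercising this custom auth outside the proxy; the master key value is illustrative, and the bare Starlette Request is only there to satisfy the signature:

import asyncio, os
from fastapi import Request

os.environ["PROXY_MASTER_KEY"] = "sk-1234"  # illustrative value

async def demo():
    request = Request({"type": "http", "method": "GET", "path": "/", "headers": []})
    key = await user_api_key_auth(request=request, api_key="sk-1234-1234")
    print(key.api_key)  # accepted: matches f"{PROXY_MASTER_KEY}-1234"

asyncio.run(demo())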


@ -4,17 +4,19 @@ import sys, os, traceback
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
import litellm import litellm
import inspect import inspect
# This file includes the custom callbacks for LiteLLM Proxy # This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml # Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
class MyCustomHandler(CustomLogger): class MyCustomHandler(CustomLogger):
def __init__(self): def __init__(self):
@ -23,36 +25,38 @@ class MyCustomHandler(CustomLogger):
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger") print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
try: try:
print_verbose(f"Logger Initialized with following methods:") print_verbose(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print_verbose the methods # Pretty print_verbose the methods
for method in methods: for method in methods:
print_verbose(f" - {method}") print_verbose(f" - {method}")
print_verbose(f"{reset_color_code}") print_verbose(f"{reset_color_code}")
except: except:
pass pass
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
print_verbose(f"Pre-API Call") print_verbose(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"Post-API Call") print_verbose(f"Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time): def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Stream") print_verbose(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!") print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!") print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj) response_cost = litellm.completion_cost(completion_response=response_obj)
assert response_cost > 0.0 assert response_cost > 0.0
return return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try: try:
print_verbose(f"On Async Failure !") print_verbose(f"On Async Failure !")
except Exception as e: except Exception as e:
@ -64,4 +68,4 @@ proxy_handler_instance = MyCustomHandler()
# need to set litellm.callbacks = [customHandler] # on the proxy # need to set litellm.callbacks = [customHandler] # on the proxy
# litellm.success_callback = [async_on_succes_logger] # litellm.success_callback = [async_on_succes_logger]
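
A hedged sketch of wiring the handler up directly in Python (the proxy normally loads it from config.yaml, as the comments above note); mock_response is assumed here purely to keep the demo offline:

import litellm

litellm.callbacks = [proxy_handler_instance]
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    mock_response="Hello!",  # assumption: avoids a live API call
)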


@ -12,25 +12,16 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ILLEGAL_DISPLAY_PARAMS = [ ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
"messages",
"api_key"
]
def _get_random_llm_message(): def _get_random_llm_message():
""" """
Get a random message from the LLM. Get a random message from the LLM.
""" """
messages = [ messages = ["Hey how's it going?", "What's 1 + 1?"]
"Hey how's it going?",
"What's 1 + 1?"
]
return [{"role": "user", "content": random.choice(messages)}]
return [
{"role": "user", "content": random.choice(messages)}
]
def _clean_litellm_params(litellm_params: dict): def _clean_litellm_params(litellm_params: dict):
@ -44,34 +35,40 @@ async def _perform_health_check(model_list: list):
""" """
Perform a health check for each model in the list. Perform a health check for each model in the list.
""" """
async def _check_img_gen_model(model_params: dict): async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None) model_params.pop("messages", None)
model_params["prompt"] = "test from litellm" model_params["prompt"] = "test from litellm"
try: try:
await litellm.aimage_generation(**model_params) await litellm.aimage_generation(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
async def _check_embedding_model(model_params: dict): async def _check_embedding_model(model_params: dict):
model_params.pop("messages", None) model_params.pop("messages", None)
model_params["input"] = ["test from litellm"] model_params["input"] = ["test from litellm"]
try: try:
await litellm.aembedding(**model_params) await litellm.aembedding(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
async def _check_model(model_params: dict): async def _check_model(model_params: dict):
try: try:
await litellm.acompletion(**model_params) await litellm.acompletion(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
tasks = [] tasks = []
@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list):
return healthy_endpoints, unhealthy_endpoints return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None): ):
""" """
Perform a health check on the system. Perform a health check on the system.
@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
""" """
if not model_list: if not model_list:
if cli_model: if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}] model_list = [
{"model_name": cli_model, "litellm_params": {"model": cli_model}}
]
else: else:
return [], [] return [], []
@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list) healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
return healthy_endpoints, unhealthy_endpoints return healthy_endpoints, unhealthy_endpoints
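
A hedged usage sketch for perform_health_check, reusing the model_list shape shown above; the model entry is illustrative:

import asyncio

model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
]
healthy, unhealthy = asyncio.run(perform_health_check(model_list=model_list))
print(f"healthy: {healthy}\nunhealthy: {unhealthy}")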

View file

@ -6,35 +6,42 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
import json, traceback import json, traceback
class MaxBudgetLimiter(CustomLogger):
class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass
def print_verbose(self, print_statement): def print_verbose(self, print_statement):
if litellm.set_verbose is True: if litellm.set_verbose is True:
print(print_statement) # noqa print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): async def async_pre_call_hook(
try: self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = cache.get_cache(cache_key) user_row = cache.get_cache(cache_key)
if user_row is None: # value not yet cached if user_row is None: # value not yet cached
return return
max_budget = user_row["max_budget"] max_budget = user_row["max_budget"]
curr_spend = user_row["spend"] curr_spend = user_row["spend"]
if max_budget is None: if max_budget is None:
return return
if curr_spend is None: if curr_spend is None:
return return
# CHECK IF REQUEST ALLOWED # CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget: if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.") raise HTTPException(status_code=429, detail="Max budget limit reached.")
except HTTPException as e: except HTTPException as e:
raise e raise e
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
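
A hedged sketch of the budget check above; the DualCache import path and the user_id field on UserAPIKeyAuth are assumptions, and the cached row mirrors the keys the hook reads:

import asyncio
from litellm.caching import DualCache  # assumed import path
from litellm.proxy._types import UserAPIKeyAuth

cache = DualCache()
cache.set_cache("user-1_user_api_key_user_id", {"max_budget": 10.0, "spend": 12.5})
limiter = MaxBudgetLimiter()
user = UserAPIKeyAuth(api_key="sk-test", user_id="user-1")  # user_id assumed to be a field

# raises HTTPException(429) because spend >= max_budget
asyncio.run(limiter.async_pre_call_hook(user, cache, data={}, call_type="completion"))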


@ -5,18 +5,25 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger):
class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None user_api_key_cache = None
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass
def print_verbose(self, print_statement): def print_verbose(self, print_statement):
if litellm.set_verbose is True: if litellm.set_verbose is True:
print(print_statement) # noqa print(print_statement) # noqa
async def async_pre_call_hook(
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests max_parallel_requests = user_api_key_dict.max_parallel_requests
@ -26,8 +33,8 @@ class MaxParallelRequestsHandler(CustomLogger):
if max_parallel_requests is None: if max_parallel_requests is None:
return return
self.user_api_key_cache = cache # save the api key cache for updating the value self.user_api_key_cache = cache # save the api key cache for updating the value
# CHECK IF REQUEST ALLOWED # CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count" request_count_api_key = f"{api_key}_request_count"
@ -35,56 +42,67 @@ class MaxParallelRequestsHandler(CustomLogger):
self.print_verbose(f"current: {current}") self.print_verbose(f"current: {current}")
if current is None: if current is None:
cache.set_cache(request_count_api_key, 1) cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests: elif int(current) < max_parallel_requests:
# Increase count for this token # Increase count for this token
cache.set_cache(request_count_api_key, int(current) + 1) cache.set_cache(request_count_api_key, int(current) + 1)
else: else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.") raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING") self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"] user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
if user_api_key is None: if user_api_key is None:
return return
if self.user_api_key_cache is None: if self.user_api_key_cache is None:
return return
request_count_api_key = f"{user_api_key}_request_count" request_count_api_key = f"{user_api_key}_request_count"
# check if it has collected an entire stream response # check if it has collected an entire stream response
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}") self.print_verbose(
f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
)
if "complete_streaming_response" in kwargs or kwargs["stream"] != True: if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
# Decrease count for this token # Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1 new_val = current - 1
self.print_verbose(f"updated_value in success call: {new_val}") self.print_verbose(f"updated_value in success call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val) self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e: except Exception as e:
self.print_verbose(e) # noqa self.print_verbose(e) # noqa
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception): async def async_log_failure_call(
self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
):
try: try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook") self.print_verbose(f"Inside Max Parallel Request Failure Hook")
api_key = user_api_key_dict.api_key api_key = user_api_key_dict.api_key
if api_key is None: if api_key is None:
return return
if self.user_api_key_cache is None: if self.user_api_key_cache is None:
return return
## decrement call count if call failed ## decrement call count if call failed
if (hasattr(original_exception, "status_code") if (
and original_exception.status_code == 429 hasattr(original_exception, "status_code")
and "Max parallel request limit reached" in str(original_exception)): and original_exception.status_code == 429
pass # ignore failed calls due to max limit being reached and "Max parallel request limit reached" in str(original_exception)
else: ):
pass # ignore failed calls due to max limit being reached
else:
request_count_api_key = f"{api_key}_request_count" request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token # Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1 new_val = current - 1
self.print_verbose(f"updated_value in failure call: {new_val}") self.print_verbose(f"updated_value in failure call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val) self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e: except Exception as e:
self.print_verbose(f"An exception occurred - {str(e)}") # noqa self.print_verbose(f"An exception occurred - {str(e)}") # noqa


@ -6,85 +6,202 @@ from datetime import datetime
import importlib import importlib
from dotenv import load_dotenv from dotenv import load_dotenv
import operator import operator
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
config_filename = "litellm.secrets" config_filename = "litellm.secrets"
# Using appdirs to determine user-specific config path # Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm") config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) user_config_path = os.getenv(
"LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
)
load_dotenv() load_dotenv()
from importlib import resources from importlib import resources
import shutil import shutil
telemetry = None telemetry = None
def run_ollama_serve(): def run_ollama_serve():
try: try:
command = ['ollama', 'serve'] command = ["ollama", "serve"]
with open(os.devnull, 'w') as devnull: with open(os.devnull, "w") as devnull:
process = subprocess.Popen(command, stdout=devnull, stderr=devnull) process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e: except Exception as e:
print(f""" print(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa """
) # noqa
def clone_subfolder(repo_url, subfolder, destination): def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo # Clone the full repo
repo_name = repo_url.split('/')[-1] repo_name = repo_url.split("/")[-1]
repo_master = os.path.join(destination, "repo_master") repo_master = os.path.join(destination, "repo_master")
subprocess.run(['git', 'clone', repo_url, repo_master]) subprocess.run(["git", "clone", repo_url, repo_master])
# Move into the subfolder # Move into the subfolder
subfolder_path = os.path.join(repo_master, subfolder) subfolder_path = os.path.join(repo_master, subfolder)
# Copy subfolder to destination # Copy subfolder to destination
for file_name in os.listdir(subfolder_path): for file_name in os.listdir(subfolder_path):
source = os.path.join(subfolder_path, file_name) source = os.path.join(subfolder_path, file_name)
if os.path.isfile(source): if os.path.isfile(source):
shutil.copy(source, destination) shutil.copy(source, destination)
else: else:
dest_path = os.path.join(destination, file_name) dest_path = os.path.join(destination, file_name)
shutil.copytree(source, dest_path) shutil.copytree(source, dest_path)
# Remove cloned repo folder
subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy")
# Remove cloned repo folder
subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy")
def is_port_in_use(port): def is_port_in_use(port):
import socket import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0 return s.connect_ex(("localhost", port)) == 0
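
A hedged usage sketch for is_port_in_use, e.g. to warn before binding the default --port:

if is_port_in_use(8000):
    print("Port 8000 is already taken; pass a free port with `litellm --port <port>`")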
@click.command() @click.command()
@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.') @click.option("--host", default="0.0.0.0", help="Host for the server to listen on.")
@click.option('--port', default=8000, help='Port to bind the server to.') @click.option("--port", default=8000, help="Port to bind the server to.")
@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up') @click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up")
@click.option('--api_base', default=None, help='API base URL.') @click.option("--api_base", default=None, help="API base URL.")
@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.') @click.option(
@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects') "--api_version",
@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")') default="2023-07-01-preview",
@click.option('--add_key', default=None, help='The model name to pass to litellm expects') help="For azure - pass in the api version.",
@click.option('--headers', default=None, help='headers for the API call') )
@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config') @click.option(
@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input') "--model", "-m", default=None, help="The model name to pass to litellm expects"
@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints') )
@click.option('--temperature', default=None, type=float, help='Set temperature for the model') @click.option(
@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model') "--alias",
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls') default=None,
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params') help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt') )
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`') @click.option(
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') "--add_key", default=None, help="The model name to pass to litellm expects"
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`') )
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version') @click.option("--headers", default=None, help="headers for the API call")
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.') @click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml') @click.option(
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to') "--debug", default=False, is_flag=True, type=bool, help="To debug the input"
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response') )
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with') @click.option(
@click.option('--local', is_flag=True, default=False, help='for local debugging') "--use_queue",
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version): default=False,
is_flag=True,
type=bool,
help="To use celery workers for async endpoints",
)
@click.option(
"--temperature", default=None, type=float, help="Set temperature for the model"
)
@click.option(
"--max_tokens", default=None, type=int, help="Set max tokens for the model"
)
@click.option(
"--request_timeout",
default=600,
type=int,
help="Set timeout in seconds for completion calls",
)
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
@click.option(
"--add_function_to_prompt",
is_flag=True,
help="If function passed but unsupported, pass it as prompt",
)
@click.option(
"--config",
"-c",
default=None,
help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
)
@click.option(
"--max_budget",
default=None,
type=float,
help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
)
@click.option(
"--telemetry",
default=True,
type=bool,
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
)
@click.option(
"--version",
"-v",
default=False,
is_flag=True,
type=bool,
help="Print LiteLLM version",
)
@click.option(
"--logs",
flag_value=False,
type=int,
help='Gets the "n" most recent logs. By default gets most recent log.',
)
@click.option(
"--health",
flag_value=True,
help="Make a chat/completions request to all llms in config.yaml",
)
@click.option(
"--test",
flag_value=True,
help="proxy chat completions url to make a test request to",
)
@click.option(
"--test_async",
default=False,
is_flag=True,
help="Calls async endpoints /queue/requests and /queue/response",
)
@click.option(
"--num_requests",
default=10,
type=int,
help="Number of requests to hit async endpoint with",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(
host,
port,
api_base,
api_version,
model,
alias,
add_key,
headers,
save,
debug,
temperature,
max_tokens,
request_timeout,
drop_params,
add_function_to_prompt,
config,
max_budget,
telemetry,
logs,
test,
local,
num_workers,
test_async,
num_requests,
use_queue,
health,
version,
):
global feature_telemetry global feature_telemetry
args = locals() args = locals()
if local: if local:
@ -92,51 +209,60 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
else: else:
try: try:
from .proxy_server import app, save_worker_config, usage_telemetry from .proxy_server import app, save_worker_config, usage_telemetry
except ImportError as e: except ImportError as e:
from proxy_server import app, save_worker_config, usage_telemetry from proxy_server import app, save_worker_config, usage_telemetry
feature_telemetry = usage_telemetry feature_telemetry = usage_telemetry
if logs is not None: if logs is not None:
if logs == 0: # default to 1 if logs == 0: # default to 1
logs = 1 logs = 1
try: try:
with open('api_log.json') as f: with open("api_log.json") as f:
data = json.load(f) data = json.load(f)
# convert keys to datetime objects # convert keys to datetime objects
log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()} log_times = {
datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()
}
# sort by timestamp # sort by timestamp
sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True) sorted_times = sorted(
log_times.items(), key=operator.itemgetter(0), reverse=True
)
# get n recent logs # get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]} recent_logs = {
k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]
}
print(json.dumps(recent_logs, indent=4)) # noqa print(json.dumps(recent_logs, indent=4)) # noqa
except: except:
raise Exception("LiteLLM: No logs saved!") raise Exception("LiteLLM: No logs saved!")
return return
if version == True: if version == True:
pkg_version = importlib.metadata.version("litellm") pkg_version = importlib.metadata.version("litellm")
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n') click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
return return
if model and "ollama" in model and api_base is None: if model and "ollama" in model and api_base is None:
run_ollama_serve() run_ollama_serve()
if test_async is True: if test_async is True:
import requests, concurrent, time import requests, concurrent, time
api_base = f"http://{host}:{port}" api_base = f"http://{host}:{port}"
def _make_openai_completion(): def _make_openai_completion():
data = { data = {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Write a short poem about the moon"}] "messages": [
{"role": "user", "content": "Write a short poem about the moon"}
],
} }
response = requests.post("http://0.0.0.0:8000/queue/request", json=data) response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
response = response.json() response = response.json()
while True: while True:
try: try:
url = response["url"] url = response["url"]
polling_url = f"{api_base}{url}" polling_url = f"{api_base}{url}"
polling_response = requests.get(polling_url) polling_response = requests.get(polling_url)
@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished": if status == "finished":
llm_response = polling_response["result"] llm_response = polling_response["result"]
break break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
) # noqa
time.sleep(0.5) time.sleep(0.5)
except Exception as e: except Exception as e:
print("got exception in polling", e) print("got exception in polling", e)
@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
futures = [] futures = []
start_time = time.time() start_time = time.time()
# Make concurrent calls # Make concurrent calls
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor: with concurrent.futures.ThreadPoolExecutor(
max_workers=concurrent_calls
) as executor:
for _ in range(concurrent_calls): for _ in range(concurrent_calls):
futures.append(executor.submit(_make_openai_completion)) futures.append(executor.submit(_make_openai_completion))
@ -171,7 +301,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
failed_calls = 0 failed_calls = 0
for future in futures: for future in futures:
if future.done(): if future.done():
if future.result() is not None: if future.result() is not None:
successful_calls += 1 successful_calls += 1
else: else:
@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return return
if health != False: if health != False:
import requests import requests
print("\nLiteLLM: Health Testing models in config") print("\nLiteLLM: Health Testing models in config")
response = requests.get(url=f"http://{host}:{port}/health") response = requests.get(url=f"http://{host}:{port}/health")
print(json.dumps(response.json(), indent=4)) print(json.dumps(response.json(), indent=4))
return return
if test != False: if test != False:
click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy') click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy")
import openai import openai
if test == True: # flag value set
api_base = f"http://{host}:{port}"
else:
api_base = test
client = openai.OpenAI(
api_key="My API Key",
base_url=api_base
)
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ if test == True: # flag value set
{ api_base = f"http://{host}:{port}"
"role": "user", else:
"content": "this is a test request, write a short poem" api_base = test
} client = openai.OpenAI(api_key="My API Key", base_url=api_base)
], max_tokens=256)
click.echo(f'\nLiteLLM: response from proxy {response}') response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem",
}
],
max_tokens=256,
)
click.echo(f"\nLiteLLM: response from proxy {response}")
print("\n Making streaming request to proxy") print("\n Making streaming request to proxy")
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ response = client.chat.completions.create(
{ model="gpt-3.5-turbo",
"role": "user", messages=[
"content": "this is a test request, write a short poem" {
} "role": "user",
], "content": "this is a test request, write a short poem",
stream=True, }
],
stream=True,
) )
for chunk in response: for chunk in response:
click.echo(f'LiteLLM: streaming response from proxy {chunk}') click.echo(f"LiteLLM: streaming response from proxy {chunk}")
print("\n making completion request to proxy") print("\n making completion request to proxy")
response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem') response = client.completions.create(
model="gpt-3.5-turbo", prompt="this is a test request, write a short poem"
)
print(response) print(response)
return return
else: else:
if headers: if headers:
headers = json.loads(headers) headers = json.loads(headers)
save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, use_queue=use_queue) save_worker_config(
model=model,
alias=alias,
api_base=api_base,
api_version=api_version,
debug=debug,
temperature=temperature,
max_tokens=max_tokens,
request_timeout=request_timeout,
max_budget=max_budget,
telemetry=telemetry,
drop_params=drop_params,
add_function_to_prompt=add_function_to_prompt,
headers=headers,
save=save,
config=config,
use_queue=use_queue,
)
try: try:
import uvicorn import uvicorn
except: except:
raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`") raise ImportError(
"Uvicorn needs to be imported. Run - `pip install uvicorn`"
)
if port == 8000 and is_port_in_use(port): if port == 8000 and is_port_in_use(port):
port = random.randint(1024, 49152) port = random.randint(1024, 49152)
uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers) uvicorn.run(
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
)
if __name__ == "__main__": if __name__ == "__main__":
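The --test_async path above exercises the proxy's async queue: it POSTs a job to /queue/request, then polls the url returned in the response until the job's status is "finished". A minimal client-side sketch of that flow, assuming a proxy already running on the default host and port (the payload and variable names here are illustrative, not part of the diff):

import time
import requests

api_base = "http://0.0.0.0:8000"  # assumes the proxy is already running locally

# Submit a job to the async queue
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Write a short poem about the moon"}],
}
job = requests.post(f"{api_base}/queue/request", json=payload).json()

# Poll the returned URL until the job reports it is finished
polling_url = f"{api_base}{job['url']}"
while True:
    polled = requests.get(polling_url).json()
    if polled.get("status") == "finished":
        print(polled["result"])
        break
    time.sleep(0.5)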


@ -1,71 +1,77 @@
from dotenv import load_dotenv # from dotenv import load_dotenv
load_dotenv()
import json, subprocess
import psutil # Import the psutil library
import atexit
try:
### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
from celery import Celery
import redis
except:
import sys
subprocess.check_call( # load_dotenv()
[ # import json, subprocess
sys.executable, # import psutil # Import the psutil library
"-m", # import atexit
"pip",
"install",
"redis",
"celery"
]
)
import time # try:
import sys, os # ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
sys.path.insert( # from celery import Celery
0, os.path.abspath("../../..") # import redis
) # Adds the parent directory to the system path - for litellm local dev # except:
import litellm # import sys
# Redis connection setup # subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5)
redis_client = redis.Redis(connection_pool=pool)
# Celery setup # import time
celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}") # import sys, os
celery_app.conf.update(
broker_pool_limit = None, # sys.path.insert(
broker_transport_options = {'connection_pool': pool}, # 0, os.path.abspath("../../..")
result_backend_transport_options = {'connection_pool': pool}, # ) # Adds the parent directory to the system path - for litellm local dev
) # import litellm
# # Redis connection setup
# pool = redis.ConnectionPool(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# db=0,
# max_connections=5,
# )
# redis_client = redis.Redis(connection_pool=pool)
# # Celery setup
# celery_app = Celery(
# "tasks",
# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# )
# celery_app.conf.update(
# broker_pool_limit=None,
# broker_transport_options={"connection_pool": pool},
# result_backend_transport_options={"connection_pool": pool},
# )
# Celery task # # Celery task
@celery_app.task(name='process_job', max_retries=3) # @celery_app.task(name="process_job", max_retries=3)
def process_job(*args, **kwargs): # def process_job(*args, **kwargs):
try: # try:
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore # llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
response = llm_router.completion(*args, **kwargs) # type: ignore # response = llm_router.completion(*args, **kwargs) # type: ignore
if isinstance(response, litellm.ModelResponse): # if isinstance(response, litellm.ModelResponse):
response = response.model_dump_json() # response = response.model_dump_json()
return json.loads(response) # return json.loads(response)
return str(response) # return str(response)
except Exception as e: # except Exception as e:
raise e # raise e
# Ensure Celery workers are terminated when the script exits
def cleanup():
try:
# Get a list of all running processes
for process in psutil.process_iter(attrs=['pid', 'name']):
# Check if the process is a Celery worker process
if process.info['name'] == 'celery':
print(f"Terminating Celery worker with PID {process.info['pid']}")
# Terminate the Celery worker process
psutil.Process(process.info['pid']).terminate()
except Exception as e:
print(f"Error during cleanup: {e}")
# Register the cleanup function to run when the script exits # # Ensure Celery workers are terminated when the script exits
atexit.register(cleanup) # def cleanup():
# try:
# # Get a list of all running processes
# for process in psutil.process_iter(attrs=["pid", "name"]):
# # Check if the process is a Celery worker process
# if process.info["name"] == "celery":
# print(f"Terminating Celery worker with PID {process.info['pid']}")
# # Terminate the Celery worker process
# psutil.Process(process.info["pid"]).terminate()
# except Exception as e:
# print(f"Error during cleanup: {e}")
# # Register the cleanup function to run when the script exits
# atexit.register(cleanup)
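The module above (now commented out) defined a Redis-backed Celery app plus a process_job task that builds a litellm.Router from the passed llm_model_list and runs the completion. For reference only, a producer would hand work to that task through Celery's standard delay/send_task API; this is a hedged sketch assuming the commented-out module were re-enabled, not code from this commit:

# Illustrative producer - assumes the celery_app module above is active and a worker is running.
from celery_app import celery_app, process_job

kwargs = {
    "llm_model_list": [],  # router deployments; process_job pops this before routing
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hello"}],
}

# Enqueue by registered task name, or via the task object directly
async_result = celery_app.send_task("process_job", kwargs=kwargs)
# async_result = process_job.delay(**kwargs)

print(async_result.get(timeout=60))  # blocks until the worker returns the routed completion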


@ -1,12 +1,15 @@
import os import os
from multiprocessing import Process from multiprocessing import Process
def run_worker(cwd): def run_worker(cwd):
os.chdir(cwd) os.chdir(cwd)
os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info") os.system(
"celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info"
)
def start_worker(cwd): def start_worker(cwd):
cwd += "/queue" cwd += "/queue"
worker_process = Process(target=run_worker, args=(cwd,)) worker_process = Process(target=run_worker, args=(cwd,))
worker_process.start() worker_process.start()


@ -1,26 +1,34 @@
import sys, os # import sys, os
from dotenv import load_dotenv # from dotenv import load_dotenv
load_dotenv()
# Add the path to the local folder to sys.path
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
def start_rq_worker(): # load_dotenv()
from rq import Worker, Queue, Connection # # Add the path to the local folder to sys.path
from redis import Redis # sys.path.insert(
# Set up RQ connection # 0, os.path.abspath("../../..")
redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) # ) # Adds the parent directory to the system path - for litellm local dev
print(redis_conn.ping()) # Should print True if connected successfully
# Create a worker and add the queue
try:
queue = Queue(connection=redis_conn)
worker = Worker([queue], connection=redis_conn)
except Exception as e:
print(f"Error setting up worker: {e}")
exit()
with Connection(redis_conn):
worker.work()
start_rq_worker()
# def start_rq_worker():
# from rq import Worker, Queue, Connection
# from redis import Redis
# # Set up RQ connection
# redis_conn = Redis(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# )
# print(redis_conn.ping()) # Should print True if connected successfully
# # Create a worker and add the queue
# try:
# queue = Queue(connection=redis_conn)
# worker = Worker([queue], connection=redis_conn)
# except Exception as e:
# print(f"Error setting up worker: {e}")
# exit()
# with Connection(redis_conn):
# worker.work()
# start_rq_worker()


@ -4,20 +4,18 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
api_key="test",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
# Your existing code for litellm_completion goes here # Your existing code for litellm_completion goes here
try: try:
response = await litellm_client.chat.completions.create( response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request messages=[
{"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180}
], # this is about 4k tokens per request
) )
print(response)
return response return response
except Exception as e: except Exception as e:
@ -25,7 +23,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log: with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n") error_log.write(f"Error during completion: {str(e)}\n")
pass pass
async def main(): async def main():
@ -45,6 +42,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()


@ -4,16 +4,13 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
api_key="sk-1234",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
# Your existing code for litellm_completion goes here # Your existing code for litellm_completion goes here
try: try:
response = await litellm_client.chat.completions.create( response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
) )
@ -24,7 +21,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log: with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n") error_log.write(f"Error during completion: {str(e)}\n")
pass pass
async def main(): async def main():
@ -44,6 +40,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()


@ -1,4 +1,4 @@
# test time it takes to make 100 concurrent embedding requests to OpenaI # test time it takes to make 100 concurrent embedding requests to OpenaI
import sys, os import sys, os
import traceback import traceback
@ -14,16 +14,16 @@ import pytest
import litellm import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100 question = "embed this very long text" * 100
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere # Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions # show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
import concurrent.futures import concurrent.futures
import random import random
@ -35,7 +35,10 @@ def make_openai_completion(question):
try: try:
start_time = time.time() start_time = time.time()
import openai import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"]
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create( response = client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=[question], input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# ) # )
return None return None
start_time = time.time() start_time = time.time()
# Number of concurrent calls (you can adjust this) # Number of concurrent calls (you can adjust this)
concurrent_calls = 500 concurrent_calls = 500


@ -4,24 +4,20 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
api_key="test",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
# Your existing code for litellm_completion goes here # Your existing code for litellm_completion goes here
try: try:
print("starting embedding calls") print("starting embedding calls")
response = await litellm_client.embeddings.create( response = await litellm_client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input = [ input=[
"hello who are you" * 2000, "hello who are you" * 2000,
"hello who are you tomorrow 1234" * 1000, "hello who are you tomorrow 1234" * 1000,
"hello who are you tomorrow 1234" * 1000 "hello who are you tomorrow 1234" * 1000,
] ],
) )
print(response) print(response)
return response return response
@ -31,7 +27,6 @@ async def litellm_completion():
with open("error_log.txt", "a") as error_log: with open("error_log.txt", "a") as error_log:
error_log.write(f"Error during completion: {str(e)}\n") error_log.write(f"Error during completion: {str(e)}\n")
pass pass
async def main(): async def main():
@ -51,6 +46,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()


@ -1,4 +1,4 @@
# test time it takes to make 100 concurrent embedding requests to OpenaI # test time it takes to make 100 concurrent embedding requests to OpenaI
import sys, os import sys, os
import traceback import traceback
@ -14,16 +14,16 @@ import pytest
import litellm import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100 question = "embed this very long text" * 100
# make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array. # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
# Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere # Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
# show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions # show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
import concurrent.futures import concurrent.futures
import random import random
@ -35,7 +35,10 @@ def make_openai_completion(question):
try: try:
start_time = time.time() start_time = time.time()
import openai import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create( response = client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=[question], input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# ) # )
return None return None
start_time = time.time() start_time = time.time()
# Number of concurrent calls (you can adjust this) # Number of concurrent calls (you can adjust this)
concurrent_calls = 500 concurrent_calls = 500


@ -2,6 +2,7 @@ import requests
import time import time
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -12,37 +13,35 @@ base_url = "https://api.litellm.ai"
# Step 1 Add a config to the proxy, generate a temp key # Step 1 Add a config to the proxy, generate a temp key
config = { config = {
"model_list": [ "model_list": [
{ {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'], "api_key": os.environ["OPENAI_API_KEY"],
} },
}, },
{ {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.environ['AZURE_API_KEY'], "api_key": os.environ["AZURE_API_KEY"],
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview" "api_version": "2023-07-01-preview",
} },
} },
] ]
} }
print("STARTING LOAD TEST Q") print("STARTING LOAD TEST Q")
print(os.environ['AZURE_API_KEY']) print(os.environ["AZURE_API_KEY"])
response = requests.post( response = requests.post(
url=f"{base_url}/key/generate", url=f"{base_url}/key/generate",
json={ json={
"config": config, "config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key "duration": "30d", # default to 30d, set it to 30m if you want a temp key
}, },
headers={ headers={"Authorization": "Bearer sk-hosted-litellm"},
"Authorization": "Bearer sk-hosted-litellm"
}
) )
print("\nresponse from generating key", response.text) print("\nresponse from generating key", response.text)
@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key)
import concurrent.futures import concurrent.futures
def create_job_and_poll(request_num): def create_job_and_poll(request_num):
print(f"Creating a job on the proxy for request {request_num}") print(f"Creating a job on the proxy for request {request_num}")
job_response = requests.post( job_response = requests.post(
url=f"{base_url}/queue/request", url=f"{base_url}/queue/request",
json={ json={
'model': 'gpt-3.5-turbo', "model": "gpt-3.5-turbo",
'messages': [ "messages": [
{'role': 'system', 'content': 'write a short poem'}, {"role": "system", "content": "write a short poem"},
], ],
}, },
headers={ headers={"Authorization": f"Bearer {generated_key}"},
"Authorization": f"Bearer {generated_key}"
}
) )
print(job_response.status_code) print(job_response.status_code)
print(job_response.text) print(job_response.text)
@ -84,12 +82,12 @@ def create_job_and_poll(request_num):
try: try:
print(f"\nPolling URL for request {request_num}", polling_url) print(f"\nPolling URL for request {request_num}", polling_url)
polling_response = requests.get( polling_response = requests.get(
url=polling_url, url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
headers={ )
"Authorization": f"Bearer {generated_key}" print(
} f"\nResponse from polling url for request {request_num}",
polling_response.text,
) )
print(f"\nResponse from polling url for request {request_num}", polling_response.text)
polling_response = polling_response.json() polling_response = polling_response.json()
status = polling_response.get("status", None) status = polling_response.get("status", None)
if status == "finished": if status == "finished":
@ -109,6 +107,7 @@ def create_job_and_poll(request_num):
except Exception as e: except Exception as e:
print("got exception when polling", e) print("got exception when polling", e)
# Number of requests # Number of requests
num_requests = 100 num_requests = 100
@ -118,4 +117,4 @@ with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor
futures = [executor.submit(create_job_and_poll, i) for i in range(num_requests)] futures = [executor.submit(create_job_and_poll, i) for i in range(num_requests)]
# Wait for all futures to complete # Wait for all futures to complete
concurrent.futures.wait(futures) concurrent.futures.wait(futures)


@ -1,4 +1,4 @@
# # This tests the litelm proxy # # This tests the litelm proxy
# # it makes async Completion requests with streaming # # it makes async Completion requests with streaming
# import openai # import openai
@ -8,14 +8,14 @@
# async def test_async_completion(): # async def test_async_completion():
# response = await ( # response = await (
# model="gpt-3.5-turbo", # model="gpt-3.5-turbo",
# prompt='this is a test request, write a short poem', # prompt='this is a test request, write a short poem',
# ) # )
# print(response) # print(response)
# print("test_streaming") # print("test_streaming")
# response = await openai.chat.completions.create( # response = await openai.chat.completions.create(
# model="gpt-3.5-turbo", # model="gpt-3.5-turbo",
# prompt='this is a test request, write a short poem', # prompt='this is a test request, write a short poem',
# stream=True # stream=True
# ) # )
@ -26,4 +26,3 @@
# import asyncio # import asyncio
# asyncio.run(test_async_completion()) # asyncio.run(test_async_completion())


@ -34,7 +34,3 @@
# response = claude_chat(messages) # response = claude_chat(messages)
# print(response) # print(response)


@ -2,6 +2,7 @@ import requests
import time import time
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -12,26 +13,24 @@ base_url = "https://api.litellm.ai"
# Step 1 Add a config to the proxy, generate a temp key # Step 1 Add a config to the proxy, generate a temp key
config = { config = {
"model_list": [ "model_list": [
{ {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'], "api_key": os.environ["OPENAI_API_KEY"],
} },
} }
] ]
} }
response = requests.post( response = requests.post(
url=f"{base_url}/key/generate", url=f"{base_url}/key/generate",
json={ json={
"config": config, "config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key "duration": "30d", # default to 30d, set it to 30m if you want a temp key
}, },
headers={ headers={"Authorization": "Bearer sk-hosted-litellm"},
"Authorization": "Bearer sk-hosted-litellm"
}
) )
print("\nresponse from generating key", response.text) print("\nresponse from generating key", response.text)
@ -45,22 +44,23 @@ print("Creating a job on the proxy")
job_response = requests.post( job_response = requests.post(
url=f"{base_url}/queue/request", url=f"{base_url}/queue/request",
json={ json={
'model': 'gpt-3.5-turbo', "model": "gpt-3.5-turbo",
'messages': [ "messages": [
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'}, {
"role": "system",
"content": f"You are a helpful assistant. What is your name",
},
], ],
}, },
headers={ headers={"Authorization": f"Bearer {generated_key}"},
"Authorization": f"Bearer {generated_key}"
}
) )
print(job_response.status_code) print(job_response.status_code)
print(job_response.text) print(job_response.text)
print("\nResponse from creating job", job_response.text) print("\nResponse from creating job", job_response.text)
job_response = job_response.json() job_response = job_response.json()
job_id = job_response["id"] # type: ignore job_id = job_response["id"] # type: ignore
polling_url = job_response["url"] # type: ignore polling_url = job_response["url"] # type: ignore
polling_url = f"{base_url}{polling_url}" polling_url = f"{base_url}{polling_url}"
print("\nCreated Job, Polling Url", polling_url) print("\nCreated Job, Polling Url", polling_url)
# Step 3: Poll the request # Step 3: Poll the request
@ -68,16 +68,13 @@ while True:
try: try:
print("\nPolling URL", polling_url) print("\nPolling URL", polling_url)
polling_response = requests.get( polling_response = requests.get(
url=polling_url, url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
headers={
"Authorization": f"Bearer {generated_key}"
}
) )
print("\nResponse from polling url", polling_response.text) print("\nResponse from polling url", polling_response.text)
polling_response = polling_response.json() polling_response = polling_response.json()
status = polling_response.get("status", None) # type: ignore status = polling_response.get("status", None) # type: ignore
if status == "finished": if status == "finished":
llm_response = polling_response["result"] # type: ignore llm_response = polling_response["result"] # type: ignore
print("LLM Response") print("LLM Response")
print(llm_response) print(llm_response)
break break


@ -8,16 +8,19 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status from fastapi import HTTPException, status
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
### LOGGING ###
class ProxyLogging:
""" """
Logging/Custom Handlers for proxy. Logging/Custom Handlers for proxy.
Implemented mainly to: Implemented mainly to:
- log successful/failed db read/writes - log successful/failed db read/writes
- support the max parallel request integration - support the max parallel request integration
""" """
@ -25,15 +28,15 @@ class ProxyLogging:
## INITIALIZE LITELLM CALLBACKS ## ## INITIALIZE LITELLM CALLBACKS ##
self.call_details: dict = {} self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler() self.max_parallel_request_limiter = MaxParallelRequestsHandler()
self.max_budget_limiter = MaxBudgetLimiter() self.max_budget_limiter = MaxBudgetLimiter()
pass pass
def _init_litellm_callbacks(self): def _init_litellm_callbacks(self):
print_verbose(f"INITIALIZING LITELLM CALLBACKS!") print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_parallel_request_limiter)
litellm.callbacks.append(self.max_budget_limiter) litellm.callbacks.append(self.max_budget_limiter)
for callback in litellm.callbacks: for callback in litellm.callbacks:
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback)
if callback not in litellm.success_callback: if callback not in litellm.success_callback:
@ -44,7 +47,7 @@ class ProxyLogging:
litellm._async_success_callback.append(callback) litellm._async_success_callback.append(callback)
if callback not in litellm._async_failure_callback: if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) litellm._async_failure_callback.append(callback)
if ( if (
len(litellm.input_callback) > 0 len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0 or len(litellm.success_callback) > 0
@ -57,31 +60,41 @@ class ProxyLogging:
+ litellm.failure_callback + litellm.failure_callback
) )
) )
litellm.utils.set_callbacks( litellm.utils.set_callbacks(callback_list=callback_list)
callback_list=callback_list
)
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]): async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal["completion", "embeddings"],
):
""" """
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
Covers: Covers:
1. /chat/completions 1. /chat/completions
2. /embeddings 2. /embeddings
""" """
try: try:
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__): if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type) callback.__class__
if response is not None: ):
response = await callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None:
data = response data = response
print_verbose(f'final data being sent to {call_type} call: {data}') print_verbose(f"final data being sent to {call_type} call: {data}")
return data return data
except Exception as e: except Exception as e:
raise e raise e
async def success_handler(self, *args, **kwargs): async def success_handler(self, *args, **kwargs):
""" """
Log successful db read/writes Log successful db read/writes
""" """
@ -93,26 +106,31 @@ class ProxyLogging:
Currently only logs exceptions to sentry Currently only logs exceptions to sentry
""" """
if litellm.utils.capture_exception: if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception) litellm.utils.capture_exception(error=original_exception)
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): async def post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
""" """
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
Covers: Covers:
1. /chat/completions 1. /chat/completions
2. /embeddings 2. /embeddings
""" """
for callback in litellm.callbacks: for callback in litellm.callbacks:
try: try:
if isinstance(callback, CustomLogger): if isinstance(callback, CustomLogger):
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception) await callback.async_post_call_failure_hook(
except Exception as e: user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
)
except Exception as e:
raise e raise e
return return
### DB CONNECTOR ### ### DB CONNECTOR ###
# Define the retry decorator with backoff strategy # Define the retry decorator with backoff strategy
@ -121,9 +139,12 @@ def on_backoff(details):
# The 'tries' key in the details dictionary contains the number of completed tries # The 'tries' key in the details dictionary contains the number of completed tries
print_verbose(f"Backing off... this was attempt #{details['tries']}") print_verbose(f"Backing off... this was attempt #{details['tries']}")
class PrismaClient: class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object ## init logging object
self.proxy_logging_obj = proxy_logging_obj self.proxy_logging_obj = proxy_logging_obj
@ -136,23 +157,24 @@ class PrismaClient:
os.chdir(dname) os.chdir(dname)
try: try:
subprocess.run(['prisma', 'generate']) subprocess.run(["prisma", "generate"])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
# Now you can import the Prisma Client # Now you can import the Prisma Client
from prisma import Client # type: ignore from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
self.db = Client() # Client to connect to Prisma db
def hash_token(self, token: str): def hash_token(self, token: str):
# Hash the string using SHA-256 # Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest() hashed_token = hashlib.sha256(token.encode()).hexdigest()
return hashed_token return hashed_token
def jsonify_object(self, data: dict) -> dict: def jsonify_object(self, data: dict) -> dict:
db_data = copy.deepcopy(data) db_data = copy.deepcopy(data)
for k, v in db_data.items(): for k, v in db_data.items():
@ -162,233 +184,258 @@ class PrismaClient:
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None): async def get_data(
try: self,
token: Optional[str] = None,
expires: Optional[Any] = None,
user_id: Optional[str] = None,
):
try:
response = None response = None
if token is not None: if token is not None:
# check if plain text or hash # check if plain text or hash
hashed_token = token hashed_token = token
if token.startswith("sk-"): if token.startswith("sk-"):
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={ where={"token": hashed_token}
"token": hashed_token )
}
)
if response: if response:
# Token exists, now check expiration. # Token exists, now check expiration.
if response.expires is not None and expires is not None: if response.expires is not None and expires is not None:
if response.expires >= expires: if response.expires >= expires:
# Token exists and is not expired. # Token exists and is not expired.
return response return response
else: else:
# Token exists but is expired. # Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
return response return response
else: else:
# Token does not exist. # Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key") raise HTTPException(
elif user_id is not None: status_code=status.HTTP_401_UNAUTHORIZED,
response = await self.db.litellm_usertable.find_unique( # type: ignore detail="invalid user key",
where={ )
"user_id": user_id, elif user_id is not None:
} response = await self.db.litellm_usertable.find_unique( # type: ignore
) where={
"user_id": user_id,
}
)
return response return response
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def insert_data(self, data: dict): async def insert_data(self, data: dict):
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
""" """
try: try:
token = data["token"] token = data["token"]
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None) max_budget = db_data.pop("max_budget", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={ where={
'token': hashed_token, "token": hashed_token,
}, },
data={ data={
"create": {**db_data}, #type: ignore "create": {**db_data}, # type: ignore
"update": {} # don't do anything if it already exists "update": {}, # don't do anything if it already exists
} },
) )
new_user_row = await self.db.litellm_usertable.upsert( new_user_row = await self.db.litellm_usertable.upsert(
where={ where={"user_id": data["user_id"]},
'user_id': data['user_id']
},
data={ data={
"create": {"user_id": data['user_id'], "max_budget": max_budget}, "create": {"user_id": data["user_id"], "max_budget": max_budget},
"update": {} # don't do anything if it already exists "update": {}, # don't do anything if it already exists
} },
) )
return new_verification_token return new_verification_token
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None): async def update_data(
self,
token: Optional[str] = None,
data: dict = {},
user_id: Optional[str] = None,
):
""" """
Update existing data Update existing data
""" """
try: try:
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
if token is not None: if token is not None:
print_verbose(f"token: {token}") print_verbose(f"token: {token}")
# check if plain text or hash # check if plain text or hash
if token.startswith("sk-"): if token.startswith("sk-"):
token = self.hash_token(token=token) token = self.hash_token(token=token)
db_data["token"] = token db_data["token"] = token
response = await self.db.litellm_verificationtoken.update( response = await self.db.litellm_verificationtoken.update(
where={ where={"token": token}, # type: ignore
"token": token # type: ignore data={**db_data}, # type: ignore
},
data={**db_data} # type: ignore
) )
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": db_data} return {"token": token, "data": db_data}
elif user_id is not None: elif user_id is not None:
""" """
If data['spend'] + data['user'], update the user table with spend info as well If data['spend'] + data['user'], update the user table with spend info as well
""" """
update_user_row = await self.db.litellm_usertable.update( update_user_row = await self.db.litellm_usertable.update(
where={ where={"user_id": user_id}, # type: ignore
'user_id': user_id # type: ignore data={**db_data}, # type: ignore
},
data={**db_data} # type: ignore
) )
return {"user_id": user_id, "data": db_data} return {"user_id": user_id, "data": db_data}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def delete_data(self, tokens: List): async def delete_data(self, tokens: List):
""" """
Allow user to delete a key(s) Allow user to delete a key(s)
""" """
try: try:
hashed_tokens = [self.hash_token(token=token) for token in tokens] hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many( await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}} where={"token": {"in": hashed_tokens}}
) )
return {"deleted_keys": tokens} return {"deleted_keys": tokens}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
raise e self.proxy_logging_obj.failure_handler(original_exception=e)
)
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def connect(self):
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
Exception, # base exception to catch for the backoff Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def disconnect(self): async def connect(self):
try:
await self.db.connect()
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e
# Define a retrying strategy with exponential backoff
@backoff.on_exception(
backoff.expo,
Exception, # base exception to catch for the backoff
max_tries=3, # maximum number of retries
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def disconnect(self):
try: try:
await self.db.disconnect() await self.db.disconnect()
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    try:
        print_verbose(f"value: {value}")
        # Split the path by dots to separate module from instance
        parts = value.split(".")
        # The module path is all but the last part, and the instance_name is the last part
        module_name = ".".join(parts[:-1])
        instance_name = parts[-1]

        # If config_file_path is provided, use it to determine the module spec and load the module
        if config_file_path is not None:
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split("."))
            module_file_path += ".py"

            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            if spec is None:
                raise ImportError(
                    f"Could not find a module specification for {module_file_path}"
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)  # type: ignore
        else:
            # Dynamically import the module
            module = importlib.import_module(module_name)

        # Get the instance from the module
        instance = getattr(module, instance_name)
        return instance
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
    except Exception as e:
        raise e


### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
    """
    Check if a user_id exists in cache,
    if not retrieve it.
    """
    cache_key = f"{user_id}_user_api_key_user_id"
    response = cache.get_cache(key=cache_key)
    if response is None:  # Cache miss
        user_row = await db.get_data(user_id=user_id)
        cache_value = user_row.model_dump_json()
        cache.set_cache(
            key=cache_key, value=cache_value, ttl=600
        )  # store for 10 minutes
    return
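get_instance_fn turns a dotted path from the proxy config into a live Python object, loading the module either from the config file's directory or from the installed package path. A hedged usage sketch follows; the custom_callbacks module, the proxy_handler_instance attribute, and the config path are illustrative assumptions, not something this diff defines:

# custom_callbacks.py (hypothetical), placed next to the config file:
#
#     class MyHandler:
#         def log_pre_api_call(self, *args, **kwargs): ...
#
#     proxy_handler_instance = MyHandler()

handler = get_instance_fn(
    value="custom_callbacks.proxy_handler_instance",  # "module.attribute" dotted path
    config_file_path="/path/to/config.yaml",  # module is resolved relative to this directory
)
# `handler` is now the object defined in custom_callbacks.py and can be registered
# as a callback on the proxy.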
File diff suppressed because it is too large
@@ -1,96 +1,121 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic

import dotenv, os, requests
from typing import Optional

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger


class LeastBusyLoggingHandler(CustomLogger):
    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache
        self.mapping_deployment_to_id: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        """
        Log when a model is being used.
        Caching based on model group.
        """
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if deployment is None or model_group is None or id is None:
                    return

                # map deployment to id
                self.mapping_deployment_to_id[deployment] = id

                request_count_api_key = f"{model_group}_request_count"
                # update cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = (
                    request_count_dict.get(deployment, 0) + 1
                )
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                if deployment is None or model_group is None:
                    return

                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                deployment = kwargs["litellm_params"]["metadata"].get(
                    "deployment", None
                )
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )
                if deployment is None or model_group is None:
                    return

                request_count_api_key = f"{model_group}_request_count"
                # decrement count in cache
                request_count_dict = (
                    self.router_cache.get_cache(key=request_count_api_key) or {}
                )
                request_count_dict[deployment] = request_count_dict.get(deployment)
                self.router_cache.set_cache(
                    key=request_count_api_key, value=request_count_dict
                )
        except Exception as e:
            pass

    def get_available_deployments(self, model_group: str):
        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        # map deployment to id
        return_dict = {}
        for key, value in request_count_dict.items():
            return_dict[self.mapping_deployment_to_id[key]] = value
        return return_dict
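get_available_deployments only returns the in-flight request counts keyed by deployment id; choosing a deployment from that dict is left to the caller (the router). A minimal sketch of how a caller might pick the least busy deployment, assuming `handler` is a LeastBusyLoggingHandler instance and the counts shown are made up:

# e.g. usage == {"deployment-id-1": 3, "deployment-id-2": 0, "deployment-id-3": 7}
usage = handler.get_available_deployments(model_group="gpt-3.5-turbo")

if usage:
    # pick the deployment id with the fewest requests currently in flight
    least_busy_id = min(usage, key=usage.get)
else:
    least_busy_id = None  # no traffic logged yet, fall back to another routing strategy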
@@ -2,6 +2,7 @@
import pytest, sys, os
import importlib

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
@@ -11,24 +12,30 @@ import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
    """
    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
    """
    curr_dir = os.getcwd()  # Get the current working directory
    sys.path.insert(
        0, os.path.abspath("../..")
    )  # Adds the project directory to the system path

    import litellm

    importlib.reload(litellm)
    print(litellm)
    # from litellm import Router, completion, aembedding, acompletion, embedding
    yield


def pytest_collection_modifyitems(config, items):
    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
    custom_logger_tests = [
        item for item in items if "custom_logger" in item.parent.name
    ]
    other_tests = [item for item in items if "custom_logger" not in item.parent.name]

    # Sort tests based on their names
    custom_logger_tests.sort(key=lambda x: x.name)
    other_tests.sort(key=lambda x: x.name)

    # Reorder the items list
    items[:] = custom_logger_tests + other_tests
@@ -4,6 +4,7 @@
import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
@@ -11,53 +12,62 @@ import litellm
from litellm import Router
import concurrent
from dotenv import load_dotenv

load_dotenv()

model_list = [
    {  # list of model deployments
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "azure/chatgpt-v-2",
            "api_key": "bad-key",
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "gpt-3.5-turbo",  # openai model name
        "litellm_params": {  # params for litellm completion/embedding call
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "tpm": 1000000,
        "rpm": 9000,
    },
]

kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
}


def test_multiple_deployments_sync():
    import concurrent, time

    litellm.set_verbose = False
    results = []
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT")),  # type: ignore
        routing_strategy="simple-shuffle",
        set_verbose=True,
        num_retries=1,
    )  # type: ignore
    try:
        for _ in range(3):
            response = router.completion(**kwargs)
            results.append(response)
        print(results)
        router.reset()
    except Exception as e:
        print(f"FAILED TEST!")
        pytest.fail(f"An error occurred - {traceback.format_exc()}")


# test_multiple_deployments_sync()
@@ -67,13 +77,15 @@ def test_multiple_deployments_parallel():
    results = []
    futures = {}
    start_time = time.time()
    router = Router(
        model_list=model_list,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT")),  # type: ignore
        routing_strategy="simple-shuffle",
        set_verbose=True,
        num_retries=1,
    )  # type: ignore

    # Assuming you have an executor instance defined somewhere in your code
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for _ in range(5):
@@ -82,7 +94,11 @@ def test_multiple_deployments_parallel():
    # Retrieve the results from the futures
    while futures:
        done, not_done = concurrent.futures.wait(
            futures.values(),
            timeout=10,
            return_when=concurrent.futures.FIRST_COMPLETED,
        )
        for future in done:
            try:
                result = future.result()
@@ -98,12 +114,14 @@ def test_multiple_deployments_parallel():
    print(results)
    print(f"ELAPSED TIME: {end_time - start_time}")


# Assuming litellm, router, and executor are defined somewhere in your code

# test_multiple_deployments_parallel()


def test_cooldown_same_model_name():
    # users could have the same model with different api_base
    # example
    # azure/chatgpt, api_base: 1234
    # azure/chatgpt, api_base: 1235
    # if 1234 fails, it should only cooldown 1234 and then try with 1235
@@ -118,7 +136,7 @@ def test_cooldown_same_model_name():
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": "BAD_API_BASE",
                "tpm": 90,
            },
        },
        {
@@ -128,7 +146,7 @@ def test_cooldown_same_model_name():
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "tpm": 0.000001,
            },
        },
    ]
@@ -140,17 +158,12 @@ def test_cooldown_same_model_name():
            redis_port=int(os.getenv("REDIS_PORT")),
            routing_strategy="simple-shuffle",
            set_verbose=True,
            num_retries=3,
        )  # type: ignore
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hello this request will pass"}],
        )
        print(router.model_list)
        model_ids = []
@@ -159,10 +172,13 @@ def test_cooldown_same_model_name():
        print("\n litellm model ids ", model_ids)
        # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
        assert (
            model_ids[0] != model_ids[1]
        )  # ensure both models have a uuid added, and they have different names

        print("\ngot response\n", response)
    except Exception as e:
        pytest.fail(f"Got unexpected exception on router! - {e}")


test_cooldown_same_model_name()
@@ -9,68 +9,70 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import litellm


## case 1: set_function_to_prompt not set
def test_function_call_non_openai_model():
    try:
        model = "claude-instant-1"
        messages = [{"role": "user", "content": "what's the weather in sf?"}]
        functions = [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ]
        response = litellm.completion(
            model=model, messages=messages, functions=functions
        )
        pytest.fail(f"An error occurred")
    except Exception as e:
        print(e)
        pass


test_function_call_non_openai_model()


## case 2: add_function_to_prompt set
def test_function_call_non_openai_model_litellm_mod_set():
    litellm.add_function_to_prompt = True
    try:
        model = "claude-instant-1"
        messages = [{"role": "user", "content": "what's the weather in sf?"}]
        functions = [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ]
        response = litellm.completion(
            model=model, messages=messages, functions=functions
        )
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"An error occurred {e}")


# test_function_call_non_openai_model_litellm_mod_set()
@@ -1,4 +1,3 @@
import sys, os
import traceback
from dotenv import load_dotenv
@@ -8,7 +7,7 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout, acompletion
@@ -20,18 +19,18 @@ import tempfile
litellm.num_retries = 3
litellm.cache = None
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]


def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/vertex_key.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()
@@ -55,13 +54,13 @@ def load_vertex_ai_credentials():
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)


@pytest.mark.asyncio
async def get_response():
@@ -89,43 +88,80 @@ def test_vertex_ai():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    litellm.set_verbose = False
    litellm.vertex_project = "reliablekeys"

    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
            if model in [
                "code-gecko",
                "code-gecko@001",
                "code-gecko@002",
                "code-gecko@latest",
                "code-bison@001",
                "text-bison@001",
            ]:
                # our account does not have access to this model
                continue
            print("making request", model)
            response = completion(
                model=model,
                messages=[{"role": "user", "content": "hi"}],
                temperature=0.7,
            )
            print("\nModel Response", response)
            print(response)
            assert type(response.choices[0].message.content) == str
            assert len(response.choices[0].message.content) > 1
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")


# test_vertex_ai()


def test_vertex_ai_stream():
    load_vertex_ai_credentials()
    litellm.set_verbose = False
    litellm.vertex_project = "reliablekeys"
    import random

    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
            if model in [
                "code-gecko",
                "code-gecko@001",
                "code-gecko@002",
                "code-gecko@latest",
                "code-bison@001",
                "text-bison@001",
            ]:
                # our account does not have access to this model
                continue
            print("making request", model)
            response = completion(
                model=model,
                messages=[
                    {"role": "user", "content": "write 10 line code code for saying hi"}
                ],
                stream=True,
            )
            completed_str = ""
            for chunk in response:
                print(chunk)
@@ -137,47 +173,86 @@ def test_vertex_ai_stream():
            assert len(completed_str) > 4
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")


# test_vertex_ai_stream()


@pytest.mark.asyncio
async def test_async_vertexai_response():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        print(f"model being tested in async call: {model}")
        if model in [
            "code-gecko",
            "code-gecko@001",
            "code-gecko@002",
            "code-gecko@latest",
            "code-bison@001",
            "text-bison@001",
        ]:
            # our account does not have access to this model
            continue
        try:
            user_message = "Hello, how are you?"
            messages = [{"content": user_message, "role": "user"}]
            response = await acompletion(
                model=model, messages=messages, temperature=0.7, timeout=5
            )
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")


# asyncio.run(test_async_vertexai_response())


@pytest.mark.asyncio
async def test_async_vertexai_streaming_response():
    import random

    load_vertex_ai_credentials()
    test_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
    )
    test_models = random.sample(test_models, 1)
    test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        if model in [
            "code-gecko",
            "code-gecko@001",
            "code-gecko@002",
            "code-gecko@latest",
            "code-bison@001",
            "text-bison@001",
        ]:
            # our account does not have access to this model
            continue
        try:
            user_message = "Hello, how are you?"
            messages = [{"content": user_message, "role": "user"}]
            response = await acompletion(
                model="gemini-pro",
                messages=messages,
                temperature=0.7,
                timeout=5,
                stream=True,
            )
            print(f"response: {response}")
            complete_response = ""
            async for chunk in response:
@@ -185,44 +260,46 @@ async def test_async_vertexai_streaming_response():
                complete_response += chunk.choices[0].delta.content
            print(f"complete_response: {complete_response}")
            assert len(complete_response) > 0
        except litellm.Timeout as e:
            pass
        except Exception as e:
            print(e)
            pytest.fail(f"An exception occurred: {e}")


# asyncio.run(test_async_vertexai_streaming_response())


def test_gemini_pro_vision():
    try:
        load_vertex_ai_credentials()
        litellm.set_verbose = True
        litellm.num_retries = 0
        resp = litellm.completion(
            model="vertex_ai/gemini-pro-vision",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Whats in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
                            },
                        },
                    ],
                }
            ],
        )
        print(resp)
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise e


# test_gemini_pro_vision()
@@ -333,4 +410,4 @@ def test_gemini_pro_vision():
# import traceback
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()
@@ -11,8 +11,10 @@ sys.path.insert(
)  # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion, acreate

litellm.num_retries = 3


def test_sync_response():
    litellm.set_verbose = False
    user_message = "Hello, how are you?"
@@ -20,35 +22,49 @@ def test_sync_response():
    try:
        response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
        print(f"response: {response}")
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


# test_sync_response()


def test_sync_response_anyscale():
    litellm.set_verbose = False
    user_message = "Hello, how are you?"
    messages = [{"content": user_message, "role": "user"}]
    try:
        response = completion(
            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
            messages=messages,
            timeout=5,
        )
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")


# test_sync_response_anyscale()


def test_async_response_openai():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="gpt-3.5-turbo", messages=messages, timeout=5
            )
            print(f"response: {response}")
            print(f"response ms: {response._response_ms}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")
@@ -56,54 +72,75 @@ def test_async_response_openai():
    asyncio.run(test_get_response())


# test_async_response_openai()


def test_async_response_azure():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "What do you know?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="azure/gpt-turbo",
                messages=messages,
                base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
                api_key=os.getenv("AZURE_FRANCE_API_KEY"),
            )
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_get_response())


# test_async_response_azure()


def test_async_anyscale_response():
    import asyncio

    litellm.set_verbose = True

    async def test_get_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
                messages=messages,
                timeout=5,
            )
            # response = await response
            print(f"response: {response}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_get_response())


# test_async_anyscale_response()


def test_get_response_streaming():
    import asyncio

    async def test_async_call():
        user_message = "write a short poem in one sentence"
        messages = [{"content": user_message, "role": "user"}]
        try:
            litellm.set_verbose = True
            response = await acompletion(
                model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5
            )
            print(type(response))

            import inspect
@@ -116,29 +153,39 @@ def test_get_response_streaming():
            async for chunk in response:
                token = chunk["choices"][0]["delta"].get("content", "")
                if token == None:
                    continue  # openai v1.0.0 returns content=None
                output += token

            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
            print(f"output: {output}")
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_async_call())


# test_get_response_streaming()


def test_get_response_non_openai_streaming():
    import asyncio

    litellm.set_verbose = True
    litellm.num_retries = 0

    async def test_async_call():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(
                model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
                messages=messages,
                stream=True,
                timeout=5,
            )
            print(type(response))

            import inspect
@@ -158,11 +205,13 @@ def test_get_response_non_openai_streaming():
            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
        except litellm.Timeout as e:
            pass
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")
        return response

    asyncio.run(test_async_call())


# test_get_response_non_openai_streaming()
@@ -3,79 +3,105 @@
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest

sys.path.insert(0, os.path.abspath("../.."))
import openai, litellm, uuid
from openai import AsyncAzureOpenAI

client = AsyncAzureOpenAI(
    api_key=os.getenv("AZURE_API_KEY"),
    azure_endpoint=os.getenv("AZURE_API_BASE"),  # type: ignore
    api_version=os.getenv("AZURE_API_VERSION"),
)

model_list = [
    {
        "model_name": "azure-test",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
        },
    }
]

router = litellm.Router(model_list=model_list)


async def _openai_completion():
    try:
        start_time = time.time()
        response = await client.chat.completions.create(
            model="chatgpt-v-2",
            messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
            stream=True,
        )
        time_to_first_token = None
        first_token_ts = None
        init_chunk = None
        async for chunk in response:
            if (
                time_to_first_token is None
                and len(chunk.choices) > 0
                and chunk.choices[0].delta.content is not None
            ):
                first_token_ts = time.time()
                time_to_first_token = first_token_ts - start_time
                init_chunk = chunk
        end_time = time.time()
        print(
            "OpenAI Call: ",
            init_chunk,
            start_time,
            first_token_ts,
            time_to_first_token,
            end_time,
        )
        return time_to_first_token
    except Exception as e:
        print(e)
        return None


async def _router_completion():
    try:
        start_time = time.time()
        response = await router.acompletion(
            model="azure-test",
            messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
            stream=True,
        )
        time_to_first_token = None
        first_token_ts = None
        init_chunk = None
        async for chunk in response:
            if (
                time_to_first_token is None
                and len(chunk.choices) > 0
                and chunk.choices[0].delta.content is not None
            ):
                first_token_ts = time.time()
                time_to_first_token = first_token_ts - start_time
                init_chunk = chunk
        end_time = time.time()
        print(
            "Router Call: ",
            init_chunk,
            start_time,
            first_token_ts,
            time_to_first_token,
            end_time - first_token_ts,
        )
        return time_to_first_token
    except Exception as e:
        print(e)
        return None


async def test_azure_completion_streaming():
    """
    Test azure streaming call - measure on time to first (non-null) token.
    """
    n = 3  # Number of concurrent tasks
    ## OPENAI AVG. TIME
@@ -83,19 +109,20 @@ async def test_azure_completion_streaming():
    chat_completions = await asyncio.gather(*tasks)
    successful_completions = [c for c in chat_completions if c is not None]
    total_time = 0
    for item in successful_completions:
        total_time += item
    avg_openai_time = total_time / 3
    ## ROUTER AVG. TIME
    tasks = [_router_completion() for _ in range(n)]
    chat_completions = await asyncio.gather(*tasks)
    successful_completions = [c for c in chat_completions if c is not None]
    total_time = 0
    for item in successful_completions:
        total_time += item
    avg_router_time = total_time / 3
    ## COMPARE
    print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
    assert avg_router_time < avg_openai_time + 0.5


# asyncio.run(test_azure_completion_streaming())
@@ -5,6 +5,7 @@
import sys, os
import traceback
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
@@ -18,6 +19,7 @@ user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
model_val = None


def test_completion_with_no_model():
    # test on empty
    with pytest.raises(ValueError):
@@ -32,9 +34,10 @@ def test_completion_with_empty_model():
        print(f"error occurred: {e}")
        pass


# def test_completion_catch_nlp_exception():
# TEMP commented out NLP cloud API is unstable
#     try:
#         response = completion(model="dolphin", messages=messages, functions=[
#         {
#         "name": "get_current_weather",
@@ -56,65 +59,77 @@ def test_completion_with_empty_model():
#         }
#     ])
#     except Exception as e:
#         if "Function calling is not supported by nlp_cloud" in str(e):
#             pass
#         else:
#             pytest.fail(f'An error occurred {e}')


# test_completion_catch_nlp_exception()


def test_completion_invalid_param_cohere():
    try:
        response = completion(model="command-nightly", messages=messages, top_p=1)
        print(f"response: {response}")
    except Exception as e:
        if "Unsupported parameters passed: top_p" in str(e):
            pass
        else:
            pytest.fail(f"An error occurred {e}")


# test_completion_invalid_param_cohere()


def test_completion_function_call_cohere():
    try:
        response = completion(
            model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]
        )
        pytest.fail(f"An error occurred {e}")
    except Exception as e:
        print(e)
        pass


# test_completion_function_call_cohere()


def test_completion_function_call_openai():
    try:
        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
        response = completion(
            model="gpt-3.5-turbo",
            messages=messages,
            functions=[
                {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                }
            ],
        )
        print(f"response: {response}")
    except:
        pass


# test_completion_function_call_openai()


def test_completion_with_no_provider():
    # test on empty
@@ -125,6 +140,7 @@ def test_completion_with_no_provider():
        print(f"error occurred: {e}")
        pass


# test_completion_with_no_provider()

# # bad key
# temp_key = os.environ.get("OPENAI_API_KEY")
@@ -136,4 +152,4 @@ def test_completion_with_no_provider():
# except:
#     print(f"error occurred: {traceback.format_exc()}")
#     pass
# os.environ["OPENAI_API_KEY"] = str(temp_key)  # this passes linting#5
@ -4,62 +4,78 @@
import sys, os import sys, os
import traceback import traceback
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from openai import APITimeoutError as Timeout from openai import APITimeoutError as Timeout
import litellm import litellm
litellm.num_retries = 0 litellm.num_retries = 0
from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses from litellm import (
batch_completion,
batch_completion_models,
completion,
batch_completion_models_all_responses,
)
# litellm.set_verbose=True # litellm.set_verbose=True
def test_batch_completions(): def test_batch_completions():
messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)] messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)]
model = "j2-mid" model = "j2-mid"
litellm.set_verbose = True litellm.set_verbose = True
try: try:
result = batch_completion( result = batch_completion(
model=model, model=model,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
temperature=0.2, temperature=0.2,
request_timeout=1 request_timeout=1,
) )
print(result) print(result)
print(len(result)) print(len(result))
assert(len(result)==3) assert len(result) == 3
except Timeout as e: except Timeout as e:
print(f"IN TIMEOUT") print(f"IN TIMEOUT")
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
test_batch_completions() test_batch_completions()
def test_batch_completions_models(): def test_batch_completions_models():
try: try:
result = batch_completion_models( result = batch_completion_models(
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"], models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
messages=[{"role": "user", "content": "Hey, how's it going"}] messages=[{"role": "user", "content": "Hey, how's it going"}],
) )
print(result) print(result)
except Timeout as e: except Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
# test_batch_completions_models() # test_batch_completions_models()
def test_batch_completion_models_all_responses(): def test_batch_completion_models_all_responses():
try: try:
responses = batch_completion_models_all_responses( responses = batch_completion_models_all_responses(
models=["j2-light", "claude-instant-1.2"], models=["j2-light", "claude-instant-1.2"],
messages=[{"role": "user", "content": "write a poem"}], messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10 max_tokens=10,
) )
print(responses) print(responses)
assert(len(responses) == 2) assert len(responses) == 2
except Timeout as e: except Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
# test_batch_completion_models_all_responses()
# test_batch_completion_models_all_responses()
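For reference, the batch helpers exercised above can be sketched as follows; the model names are the ones already used in these tests, everything else is illustrative:

from litellm import batch_completion, batch_completion_models_all_responses

# one inner list per prompt; each list is sent as its own completion call
results = batch_completion(
    model="j2-mid",
    messages=[[{"role": "user", "content": "write a short poem"}] for _ in range(3)],
    max_tokens=10,
    temperature=0.2,
)

# the same prompt fanned out across several models, collecting every response
responses = batch_completion_models_all_responses(
    models=["j2-light", "claude-instant-1.2"],
    messages=[{"role": "user", "content": "write a poem"}],
    max_tokens=10,
)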
View file
@@ -3,12 +3,12 @@
# import sys, os, json # import sys, os, json
# import traceback # import traceback
# import pytest # import pytest
# sys.path.insert( # sys.path.insert(
# 0, os.path.abspath("../..") # 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path # ) # Adds the parent directory to the system path
# import litellm # import litellm
# litellm.set_verbose = True # litellm.set_verbose = True
# from litellm import completion, BudgetManager # from litellm import completion, BudgetManager
@@ -16,7 +16,7 @@
# ## Scenario 1: User budget enough to make call # ## Scenario 1: User budget enough to make call
# def test_user_budget_enough(): # def test_user_budget_enough():
# try: # try:
# user = "1234" # user = "1234"
# # create a budget for a user # # create a budget for a user
# budget_manager.create_budget(total_budget=10, user=user, duration="daily") # budget_manager.create_budget(total_budget=10, user=user, duration="daily")
@@ -38,7 +38,7 @@
# ## Scenario 2: User budget not enough to make call # ## Scenario 2: User budget not enough to make call
# def test_user_budget_not_enough(): # def test_user_budget_not_enough():
# try: # try:
# user = "12345" # user = "12345"
# # create a budget for a user # # create a budget for a user
# budget_manager.create_budget(total_budget=0, user=user, duration="daily") # budget_manager.create_budget(total_budget=0, user=user, duration="daily")
@@ -60,7 +60,7 @@
# except: # except:
# pytest.fail(f"An error occurred") # pytest.fail(f"An error occurred")
# ## Scenario 3: Saving budget to client # ## Scenario 3: Saving budget to client
# def test_save_user_budget(): # def test_save_user_budget():
# try: # try:
# response = budget_manager.save_data() # response = budget_manager.save_data()
@@ -70,17 +70,17 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"An error occurred: {str(e)}") # pytest.fail(f"An error occurred: {str(e)}")
# test_save_user_budget() # test_save_user_budget()
# ## Scenario 4: Getting list of users # ## Scenario 4: Getting list of users
# def test_get_users(): # def test_get_users():
# try: # try:
# response = budget_manager.get_users() # response = budget_manager.get_users()
# print(response) # print(response)
# except: # except:
# pytest.fail(f"An error occurred") # pytest.fail(f"An error occurred")
# ## Scenario 5: Reset budget at the end of duration # ## Scenario 5: Reset budget at the end of duration
# def test_reset_on_duration(): # def test_reset_on_duration():
# try: # try:
# # First, set a short duration budget for a user # # First, set a short duration budget for a user
@@ -100,7 +100,7 @@
# # Now, we need to simulate the passing of time. Since we don't want our tests to actually take days, we're going # # Now, we need to simulate the passing of time. Since we don't want our tests to actually take days, we're going
# # to cheat a little -- we'll manually adjust the "created_at" time so it seems like a day has passed. # # to cheat a little -- we'll manually adjust the "created_at" time so it seems like a day has passed.
# # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time. # # In a real-world testing scenario, we might instead use something like the `freezegun` library to mock the system time.
# one_day_in_seconds = 24 * 60 * 60 # one_day_in_seconds = 24 * 60 * 60
# budget_manager.user_dict[user]["last_updated_at"] -= one_day_in_seconds # budget_manager.user_dict[user]["last_updated_at"] -= one_day_in_seconds
@@ -108,11 +108,11 @@
# budget_manager.update_budget_all_users() # budget_manager.update_budget_all_users()
# # Make sure the budget was actually reset # # Make sure the budget was actually reset
# assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired" # assert budget_manager.get_current_cost(user) == 0, "Budget didn't reset after duration expired"
# except Exception as e: # except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}") # pytest.fail(f"An error occurred - {str(e)}")
# ## Scenario 6: passing in text: # ## Scenario 6: passing in text:
# def test_input_text_on_completion(): # def test_input_text_on_completion():
# try: # try:
# user = "12345" # user = "12345"
@@ -127,4 +127,4 @@
# except Exception as e: # except Exception as e:
# pytest.fail(f"An error occurred - {str(e)}") # pytest.fail(f"An error occurred - {str(e)}")
# test_input_text_on_completion() # test_input_text_on_completion()
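These budget-manager scenarios stay commented out; black only re-indents them. A compressed sketch of the flow they describe, assuming litellm.BudgetManager exposes the methods referenced in the comments (the constructor argument and the update_cost call are assumptions, not taken from this diff):

from litellm import BudgetManager, completion

budget_manager = BudgetManager(project_name="test_project")  # constructor arg assumed
user = "1234"
budget_manager.create_budget(total_budget=10, user=user, duration="daily")  # scenario 1
response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
budget_manager.update_cost(completion_obj=response, user=user)  # method name assumed
print(budget_manager.get_current_cost(user))  # scenarios 1/2: compare spend vs. budget
budget_manager.update_budget_all_users()  # scenario 5: reset budgets whose duration elapsed
budget_manager.save_data()  # scenario 3: persist budgets
print(budget_manager.get_users())  # scenario 4: list tracked users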
View file
@@ -14,6 +14,7 @@ import litellm	import litellm
from litellm import embedding, completion from litellm import embedding, completion
from litellm.caching import Cache from litellm.caching import Cache
import random import random
# litellm.set_verbose=True # litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}] messages = [{"role": "user", "content": "who is ishaan Github? "}]
@@ -22,23 +23,30 @@ messages = [{"role": "user", "content": "who is ishaan Github? "}]
import random import random
import string import string
def generate_random_word(length=4): def generate_random_word(length=4):
letters = string.ascii_lowercase letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length)) return "".join(random.choice(letters) for _ in range(length))
messages = [{"role": "user", "content": "who is ishaan 5222"}] messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
def test_caching_v2(): # test in memory cache
try: try:
litellm.set_verbose=True litellm.set_verbose = True
litellm.cache = Cache() litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
litellm.cache = None # disable cache litellm.cache = None # disable cache
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
@@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_v2() # test_caching_v2()
def test_caching_with_models_v2(): def test_caching_with_models_v2():
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}] messages = [
{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
]
litellm.cache = Cache() litellm.cache = Cache()
print("test2 for caching") print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@@ -63,34 +73,51 @@ def test_caching_with_models_v2():
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response # if models are different, it should not return cached response
print(f"response2: {response2}") print(f"response2: {response2}")
print(f"response3: {response3}") print(f"response3: {response3}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
# test_caching_with_models_v2() # test_caching_with_models_v2()
embedding_large_text = """ embedding_large_text = (
"""
small text small text
""" * 5 """
* 5
)
# # test_caching_with_models() # # test_caching_with_models()
def test_embedding_caching(): def test_embedding_caching():
import time import time
litellm.cache = Cache() litellm.cache = Cache()
text_to_embed = [embedding_large_text] text_to_embed = [embedding_large_text]
start_time = time.time() start_time = time.time()
embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) embedding1 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time() end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds") print(f"Embedding 1 response time: {end_time - start_time} seconds")
time.sleep(1) time.sleep(1)
start_time = time.time() start_time = time.time()
embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) embedding2 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time() end_time = time.time()
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
@@ -98,29 +125,30 @@ def test_embedding_caching():
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}") print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed") pytest.fail("Error occurred: Embedding caching failed")
# test_embedding_caching() # test_embedding_caching()
def test_embedding_caching_azure(): def test_embedding_caching_azure():
print("Testing azure embedding caching") print("Testing azure embedding caching")
import time import time
litellm.cache = Cache() litellm.cache = Cache()
text_to_embed = [embedding_large_text] text_to_embed = [embedding_large_text]
api_key = os.environ['AZURE_API_KEY'] api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ['AZURE_API_BASE'] api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ['AZURE_API_VERSION'] api_version = os.environ["AZURE_API_VERSION"]
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
start_time = time.time() start_time = time.time()
print("AZURE CONFIGS") print("AZURE CONFIGS")
@@ -133,7 +161,7 @@ def test_embedding_caching_azure():
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
caching=True caching=True,
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds") print(f"Embedding 1 response time: {end_time - start_time} seconds")
@@ -146,7 +174,7 @@ def test_embedding_caching_azure():
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
caching=True caching=True,
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
@@ -154,15 +182,16 @@ def test_embedding_caching_azure():
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}") print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed") pytest.fail("Error occurred: Embedding caching failed")
os.environ['AZURE_API_VERSION'] = api_version os.environ["AZURE_API_VERSION"] = api_version
os.environ['AZURE_API_BASE'] = api_base os.environ["AZURE_API_BASE"] = api_base
os.environ['AZURE_API_KEY'] = api_key os.environ["AZURE_API_KEY"] = api_key
# test_embedding_caching_azure() # test_embedding_caching_azure()
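Stripped of the Azure env-var shuffling, the embedding-cache checks above reduce to roughly this (the 0.1 s bound mirrors the assertion in the tests):

import time
import litellm
from litellm import embedding
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache
text = ["small text"]

embedding(model="text-embedding-ada-002", input=text, caching=True)  # first call hits the API
start = time.time()
embedding(model="text-embedding-ada-002", input=text, caching=True)  # second call should be a cache hit
assert time.time() - start <= 0.1  # cached call returns almost instantly
litellm.cache = None  # disable cache again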
@@ -170,13 +199,28 @@ def test_embedding_caching_azure():
def test_redis_cache_completion(): def test_redis_cache_completion():
litellm.set_verbose = False litellm.set_verbose = False
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache random_number = random.randint(
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] 1, 100000
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) ) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test2 for caching") print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20) response1 = completion(
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20) model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5) )
response2 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response3 = completion(
model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
)
response4 = completion(model="command-nightly", messages=messages, caching=True) response4 = completion(model="command-nightly", messages=messages, caching=True)
print("\nresponse 1", response1) print("\nresponse 1", response1)
@@ -192,49 +236,88 @@ def test_redis_cache_completion():
1 & 3 should be different, since input params are diff 1 & 3 should be different, since input params are diff
1 & 4 should be diff, since models are diff 1 & 4 should be diff, since models are diff
""" """
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT # 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']: if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like seed, max_tokens are diff it should NOT be a cache hit # if input params like seed, max_tokens are diff it should NOT be a cache hit
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response3: {response3}") print(f"response3: {response3}")
pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:") pytest.fail(
if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']: f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:"
)
if (
response1["choices"][0]["message"]["content"]
== response4["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response # if models are different, it should not return cached response
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response4: {response4}") print(f"response4: {response4}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
# test_redis_cache_completion() # test_redis_cache_completion()
def test_redis_cache_completion_stream(): def test_redis_cache_completion_stream():
try: try:
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
litellm.callbacks = [] litellm.callbacks = []
litellm.set_verbose = True litellm.set_verbose = True
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache random_number = random.randint(
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] 1, 100000
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) ) # add a random number to ensure it's always adding / reading from cache
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_1_content = "" response_1_content = ""
for chunk in response1: for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
time.sleep(0.5) time.sleep(0.5)
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_2_content = "" response_2_content = ""
for chunk in response2: for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = [] litellm.success_callback = []
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
@@ -247,99 +330,171 @@ def test_redis_cache_completion_stream():
1 & 2 should be exactly the same 1 & 2 should be exactly the same
""" """
# test_redis_cache_completion_stream() # test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream(): def test_redis_cache_acompletion_stream():
import asyncio import asyncio
try: try:
litellm.set_verbose = True litellm.set_verbose = True
random_word = generate_random_word() random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] messages = [
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) {
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response_1_content = "" response_1_content = ""
response_2_content = "" response_2_content = ""
async def call1(): async def call1():
nonlocal response_1_content nonlocal response_1_content
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1: async for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
asyncio.run(call1()) asyncio.run(call1())
time.sleep(0.5) time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n") print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2(): async def call2():
nonlocal response_2_content nonlocal response_2_content
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2: async for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content) print(response_2_content)
asyncio.run(call2()) asyncio.run(call2())
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
except Exception as e: except Exception as e:
print(e) print(e)
raise e raise e
# test_redis_cache_acompletion_stream() # test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock(): def test_redis_cache_acompletion_stream_bedrock():
import asyncio import asyncio
try: try:
litellm.set_verbose = True litellm.set_verbose = True
random_word = generate_random_word() random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] messages = [
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) {
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response_1_content = "" response_1_content = ""
response_2_content = "" response_2_content = ""
async def call1(): async def call1():
nonlocal response_1_content nonlocal response_1_content
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1: async for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
asyncio.run(call1()) asyncio.run(call1())
time.sleep(0.5) time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n") print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2(): async def call2():
nonlocal response_2_content nonlocal response_2_content
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2: async for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content) print(response_2_content)
asyncio.run(call2()) asyncio.run(call2())
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
except Exception as e: except Exception as e:
print(e) print(e)
raise e raise e
# test_redis_cache_acompletion_stream_bedrock() # test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys # redis cache with custom keys
def custom_get_cache_key(*args, **kwargs): def custom_get_cache_key(*args, **kwargs):
# return key to use for your cache: # return key to use for your cache:
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", "")) key = (
kwargs.get("model", "")
+ str(kwargs.get("messages", ""))
+ str(kwargs.get("temperature", ""))
+ str(kwargs.get("logit_bias", ""))
)
return key return key
def test_custom_redis_cache_with_key(): def test_custom_redis_cache_with_key():
messages = [{"role": "user", "content": "write a one line story"}] messages = [{"role": "user", "content": "write a one line story"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.cache.get_cache_key = custom_get_cache_key litellm.cache.get_cache_key = custom_get_cache_key
local_cache = {} local_cache = {}
@@ -356,54 +511,72 @@ def test_custom_redis_cache_with_key():
# patch this redis cache get and set call # patch this redis cache get and set call
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3) response1 = completion(
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3) model="gpt-3.5-turbo",
response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3) messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=True,
num_retries=3,
)
response3 = completion(
model="gpt-3.5-turbo",
messages=messages,
temperature=1,
caching=False,
num_retries=3,
)
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
print(f"response3: {response3}") print(f"response3: {response3}")
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
# test_custom_redis_cache_with_key() # test_custom_redis_cache_with_key()
def test_cache_override(): def test_cache_override():
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set # test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
# in this case it should not return cached responses # in this case it should not return cached responses
litellm.cache = Cache() litellm.cache = Cache()
print("Testing cache override") print("Testing cache override")
litellm.set_verbose=True litellm.set_verbose = True
# test embedding # test embedding
response1 = embedding( response1 = embedding(
model = "text-embedding-ada-002", model="text-embedding-ada-002", input=["hello who are you"], caching=False
input=[
"hello who are you"
],
caching = False
) )
start_time = time.time() start_time = time.time()
response2 = embedding( response2 = embedding(
model = "text-embedding-ada-002", model="text-embedding-ada-002", input=["hello who are you"], caching=False
input=[
"hello who are you"
],
caching = False
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached. assert (
# test_cache_override() end_time - start_time > 0.1
) # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override()
def test_custom_redis_cache_params(): def test_custom_redis_cache_params():
@@ -411,17 +584,17 @@ def test_custom_redis_cache_params():
try: try:
litellm.cache = Cache( litellm.cache = Cache(
type="redis", type="redis",
host=os.environ['REDIS_HOST'], host=os.environ["REDIS_HOST"],
port=os.environ['REDIS_PORT'], port=os.environ["REDIS_PORT"],
password=os.environ['REDIS_PASSWORD'], password=os.environ["REDIS_PASSWORD"],
db = 0, db=0,
ssl=True, ssl=True,
ssl_certfile="./redis_user.crt", ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key", ssl_keyfile="./redis_user_private.key",
ssl_ca_certs="./redis_ca.pem", ssl_ca_certs="./redis_ca.pem",
) )
print(litellm.cache.cache.redis_client) print(litellm.cache.cache.redis_client)
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
@@ -431,58 +604,126 @@ def test_custom_redis_cache_params():
def test_get_cache_key(): def test_get_cache_key():
from litellm.caching import Cache from litellm.caching import Cache
try: try:
print("Testing get_cache_key") print("Testing get_cache_key")
cache_instance = Cache() cache_instance = Cache()
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}} cache_key = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
) )
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}} cache_key_2 = cache_instance.get_cache_key(
**{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "write a one sentence poem about: 7510"}
],
"max_tokens": 40,
"temperature": 0.2,
"stream": True,
"litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
"litellm_logging_obj": {},
}
) )
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40" assert (
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs" cache_key
== "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
)
assert (
cache_key == cache_key_2
), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
embedding_cache_key = cache_instance.get_cache_key( embedding_cache_key = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', **{
'api_key': '', 'api_version': '2023-07-01-preview', "model": "azure/azure-embedding-model",
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
'caching': True, "api_key": "",
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>" "api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
} }
) )
print(embedding_cache_key) print(embedding_cache_key)
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs" assert (
embedding_cache_key
== "model: azure/azure-embedding-modelinput: ['hi who is ishaan']"
), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
# Proxy - embedding cache, test if embedding key, gets model_group and not model # Proxy - embedding cache, test if embedding key, gets model_group and not model
embedding_cache_key_2 = cache_instance.get_cache_key( embedding_cache_key_2 = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', **{
'api_key': '', 'api_version': '2023-07-01-preview', "model": "azure/azure-embedding-model",
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
'caching': True, "api_key": "",
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>", "api_version": "2023-07-01-preview",
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings', "timeout": None,
'method': 'POST', "max_retries": 0,
'headers': "input": ["hi who is ishaan"],
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', "caching": True,
'content-length': '80'}, "client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}}, "proxy_server_request": {
'user': None, "url": "http://0.0.0.0:8000/embeddings",
'metadata': {'user_api_key': None, "method": "POST",
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'}, "headers": {
'model_group': 'EMBEDDING_MODEL_GROUP', "host": "0.0.0.0:8000",
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'}, "user-agent": "curl/7.88.1",
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'}, "accept": "*/*",
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'} "content-type": "application/json",
"content-length": "80",
},
"body": {
"model": "azure-embedding-model",
"input": ["hi who is ishaan"],
},
},
"user": None,
"metadata": {
"user_api_key": None,
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"model_group": "EMBEDDING_MODEL_GROUP",
"deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview",
},
"model_info": {
"mode": "embedding",
"base_model": "text-embedding-ada-002",
"id": "20b2b515-f151-4dd5-a74f-2231e2f54e29",
},
"litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e",
"litellm_logging_obj": "<litellm.utils.Logging object at 0x12f1bddb0>",
}
) )
print(embedding_cache_key_2) print(embedding_cache_key_2)
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']" assert (
embedding_cache_key_2
== "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
)
print("passed!") print("passed!")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
pytest.fail(f"Error occurred:", e) pytest.fail(f"Error occurred:", e)
test_get_cache_key() test_get_cache_key()
# test_custom_redis_cache_params() # test_custom_redis_cache_params()
@@ -581,4 +822,4 @@ test_get_cache_key()
# assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content'] # assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
# time.sleep(2) # time.sleep(2)
# assert cache.get_cache(cache_key="test_key") is None # assert cache.get_cache(cache_key="test_key") is None
# # test_in_memory_cache_with_ttl() # # test_in_memory_cache_with_ttl()
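Condensed from the hunks above, the custom-cache-key pattern these tests exercise looks roughly like this (it assumes the REDIS_* environment variables point at a reachable Redis instance):

import os
import litellm
from litellm import completion
from litellm.caching import Cache

def custom_get_cache_key(*args, **kwargs):
    # key on the fields that should distinguish cache entries
    return (
        kwargs.get("model", "")
        + str(kwargs.get("messages", ""))
        + str(kwargs.get("temperature", ""))
        + str(kwargs.get("logit_bias", ""))
    )

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)
litellm.cache.get_cache_key = custom_get_cache_key  # override the default key builder

messages = [{"role": "user", "content": "write a one line story"}]
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True)  # same key, cache hit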
View file
@@ -1,5 +1,5 @@
#### What this tests #### #### What this tests ####
# This tests using caching w/ litellm which requires SSL=True # This tests using caching w/ litellm which requires SSL=True
import sys, os import sys, os
import time import time
@@ -18,15 +18,26 @@ from litellm import embedding, completion, Router	from litellm import embedding, completion, Router
from litellm.caching import Cache from litellm.caching import Cache
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}] messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
def test_caching_v2(): # test in memory cache
def test_caching_v2(): # test in memory cache
try: try:
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
litellm.cache = None # disable cache litellm.cache = None # disable cache
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
raise Exception() raise Exception()
@@ -34,41 +45,57 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_v2() # test_caching_v2()
def test_caching_router(): def test_caching_router():
""" """
Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit. Test scenario where litellm.cache is set but kwargs("caching") is not. This should still return a cache hit.
""" """
try: try:
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
} }
] ]
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") litellm.cache = Cache(
router = Router(model_list=model_list, type="redis",
routing_strategy="simple-shuffle", host="os.environ/REDIS_HOST_2",
set_verbose=False, port="os.environ/REDIS_PORT_2",
num_retries=1) # type: ignore password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
router = Router(
model_list=model_list,
routing_strategy="simple-shuffle",
set_verbose=False,
num_retries=1,
) # type: ignore
response1 = completion(model="gpt-3.5-turbo", messages=messages) response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages) response2 = completion(model="gpt-3.5-turbo", messages=messages)
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
litellm.cache = None # disable cache litellm.cache = None # disable cache
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content'] assert (
response2["choices"][0]["message"]["content"]
== response1["choices"][0]["message"]["content"]
)
except Exception as e: except Exception as e:
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_router()
# test_caching_router()
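The "os.environ/..." strings passed to Cache above appear to be resolved by litellm from environment variables at runtime rather than used as literal values. A trimmed sketch of the Router-plus-cache setup being tested:

import os
import litellm
from litellm import Router, completion
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    host="os.environ/REDIS_HOST_2",  # resolved from the environment by litellm
    port="os.environ/REDIS_PORT_2",
    password="os.environ/REDIS_PASSWORD_2",
    ssl="os.environ/REDIS_SSL_2",
)
router = Router(  # created as in the test; the cache hit is checked on plain completion calls
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        }
    ],
    routing_strategy="simple-shuffle",
)
# with litellm.cache set, a repeat call with identical params should be served from cache
response1 = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
response2 = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])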
View file
@@ -8,7 +8,7 @@
# 0, os.path.abspath("../..") # 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path # ) # Adds the parent directory to the system path
# import litellm # import litellm
# import asyncio # import asyncio
# litellm.set_verbose = True # litellm.set_verbose = True
# from litellm import Router # from litellm import Router
@@ -18,9 +18,9 @@
# # This enables response_model keyword # # This enables response_model keyword
# # # from client.chat.completions.create # # # from client.chat.completions.create
# # client = instructor.patch(Router(model_list=[{ # # client = instructor.patch(Router(model_list=[{
# # "model_name": "gpt-3.5-turbo", # openai model name # # "model_name": "gpt-3.5-turbo", # openai model name
# # "litellm_params": { # params for litellm completion/embedding call # # "litellm_params": { # params for litellm completion/embedding call
# # "model": "azure/chatgpt-v-2", # # "model": "azure/chatgpt-v-2",
# # "api_key": os.getenv("AZURE_API_KEY"), # # "api_key": os.getenv("AZURE_API_KEY"),
# # "api_version": os.getenv("AZURE_API_VERSION"), # # "api_version": os.getenv("AZURE_API_VERSION"),
# # "api_base": os.getenv("AZURE_API_BASE") # # "api_base": os.getenv("AZURE_API_BASE")
@@ -49,9 +49,9 @@
# from openai import AsyncOpenAI # from openai import AsyncOpenAI
# aclient = instructor.apatch(Router(model_list=[{ # aclient = instructor.apatch(Router(model_list=[{
# "model_name": "gpt-3.5-turbo", # openai model name # "model_name": "gpt-3.5-turbo", # openai model name
# "litellm_params": { # params for litellm completion/embedding call # "litellm_params": { # params for litellm completion/embedding call
# "model": "azure/chatgpt-v-2", # "model": "azure/chatgpt-v-2",
# "api_key": os.getenv("AZURE_API_KEY"), # "api_key": os.getenv("AZURE_API_KEY"),
# "api_version": os.getenv("AZURE_API_VERSION"), # "api_version": os.getenv("AZURE_API_VERSION"),
# "api_base": os.getenv("AZURE_API_BASE") # "api_base": os.getenv("AZURE_API_BASE")
@@ -71,4 +71,4 @@
# ) # )
# print(f"model: {model}") # print(f"model: {model}")
# asyncio.run(main()) # asyncio.run(main())
File diff suppressed because it is too large
View file
@@ -28,6 +28,7 @@ def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}") # print(f"user_model_dict: {user_model_dict}")
pass pass
# normal call # normal call
def test_completion_custom_provider_model_name(): def test_completion_custom_provider_model_name():
try: try:
@@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# completion with num retries + impact on exception mapping # completion with num retries + impact on exception mapping
def test_completion_with_num_retries(): def test_completion_with_num_retries():
try: try:
response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2) response = completion(
model="j2-ultra",
messages=[{"messages": "vibe", "bad": "message"}],
num_retries=2,
)
pytest.fail(f"Unmapped exception occurred") pytest.fail(f"Unmapped exception occurred")
except Exception as e: except Exception as e:
pass pass
# test_completion_with_num_retries() # test_completion_with_num_retries()
def test_completion_with_0_num_retries(): def test_completion_with_0_num_retries():
try: try:
litellm.set_verbose=False litellm.set_verbose = False
print("making request") print("making request")
# Use the completion function # Use the completion function
response = completion( response = completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}], messages=[{"gm": "vibe", "role": "user"}],
max_retries=4 max_retries=4,
) )
print(response) print(response)
@@ -69,5 +76,6 @@ def test_completion_with_0_num_retries():
print("exception", e) print("exception", e)
pass pass
# Call the test function # Call the test function
test_completion_with_0_num_retries() test_completion_with_0_num_retries()
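A minimal sketch of the two retry knobs exercised above; both requests intentionally send malformed messages, mirroring the tests (the comment on max_retries is an assumption about where it is applied, not something this diff states):

from litellm import completion

try:
    # num_retries: litellm retries the call before surfacing the mapped exception
    completion(
        model="j2-ultra",
        messages=[{"messages": "vibe", "bad": "message"}],
        num_retries=2,
    )
except Exception as e:
    print(f"exception after retries: {e}")

try:
    # max_retries: assumed to be passed through to the underlying client
    completion(
        model="gpt-3.5-turbo",
        messages=[{"gm": "vibe", "role": "user"}],
        max_retries=4,
    )
except Exception as e:
    print(f"exception: {e}")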
View file
@@ -15,77 +15,104 @@ from litellm import completion_with_config	from litellm import completion_with_config
config = { config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"model": { "model": {
"claude-instant-1": { "claude-instant-1": {"needs_moderation": True},
"needs_moderation": True
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"error_handling": { "error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
} }
} },
} },
} }
def test_config_context_window_exceeded(): def test_config_context_window_exceeded():
try: try:
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_window_exceeded()
# test_config_context_window_exceeded()
def test_config_context_moderation(): def test_config_context_moderation():
try: try:
messages=[{"role": "user", "content": "I want to kill them."}] messages = [{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config) response = completion_with_config(
model="claude-instant-1", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_moderation()
# test_config_context_moderation()
def test_config_context_default_fallback(): def test_config_context_default_fallback():
try: try:
messages=[{"role": "user", "content": "Hey, how's it going?"}] messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key") response = completion_with_config(
model="claude-instant-1",
messages=messages,
config=config,
api_key="bad-key",
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_default_fallback()
# test_config_context_default_fallback()
config = { config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613", "available_models": [
"j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"], "gpt-3.5-turbo",
"adapt_to_prompt_size": True, # type: ignore "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"j2-ultra",
"command-nightly",
"togethercomputer/llama-2-70b-chat",
"chat-bison",
"chat-bison@001",
"claude-2",
],
"adapt_to_prompt_size": True, # type: ignore
"model": { "model": {
"claude-instant-1": { "claude-instant-1": {"needs_moderation": True},
"needs_moderation": True
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"error_handling": { "error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
} }
} },
} },
} }
def test_config_context_adapt_to_prompt(): def test_config_context_adapt_to_prompt():
try: try:
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
test_config_context_adapt_to_prompt()
test_config_context_adapt_to_prompt()
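Put together, the config used by these tests maps models to moderation and fallback behaviour; a trimmed sketch of how completion_with_config consumes it:

from litellm import completion_with_config

config = {
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
    "model": {
        "claude-instant-1": {"needs_moderation": True},
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}

# an oversized prompt should be routed to the 16k fallback instead of erroring out
long_prompt = "how does a court case get to the Supreme Court?" * 1000
response = completion_with_config(
    model="gpt-3.5-turbo",
    messages=[{"content": long_prompt, "role": "user"}],
    config=config,
)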
View file
@@ -1,14 +1,16 @@
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from fastapi import Request from fastapi import Request
from dotenv import load_dotenv from dotenv import load_dotenv
import os import os
load_dotenv() load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:
print(f"api_key: {api_key}") print(f"api_key: {api_key}")
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
return UserAPIKeyAuth(api_key=api_key) return UserAPIKeyAuth(api_key=api_key)
raise Exception raise Exception
except: except:
raise Exception raise Exception
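A rough local smoke test for the auth hook above, assuming it is run in the same module; the request object is mocked because the hook only inspects the api_key:

import asyncio
import os
from unittest.mock import MagicMock

os.environ.setdefault("PROXY_MASTER_KEY", "sk-master")  # placeholder value for the sketch

async def main():
    fake_request = MagicMock()  # stand-in for fastapi.Request
    key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
    print(await user_api_key_auth(fake_request, api_key=key))  # should return UserAPIKeyAuth

asyncio.run(main())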
View file
@@ -2,30 +2,35 @@ from litellm.integrations.custom_logger import CustomLogger	from litellm.integrations.custom_logger import CustomLogger
import inspect import inspect
import litellm import litellm
class testCustomCallbackProxy(CustomLogger): class testCustomCallbackProxy(CustomLogger):
def __init__(self): def __init__(self):
self.success: bool = False # type: ignore self.success: bool = False # type: ignore
self.failure: bool = False # type: ignore self.failure: bool = False # type: ignore
self.async_success: bool = False # type: ignore self.async_success: bool = False # type: ignore
self.async_success_embedding: bool = False # type: ignore self.async_success_embedding: bool = False # type: ignore
self.async_failure: bool = False # type: ignore self.async_failure: bool = False # type: ignore
self.async_failure_embedding: bool = False # type: ignore self.async_failure_embedding: bool = False # type: ignore
self.async_completion_kwargs = None # type: ignore self.async_completion_kwargs = None # type: ignore
self.async_embedding_kwargs = None # type: ignore self.async_embedding_kwargs = None # type: ignore
self.async_embedding_response = None # type: ignore self.async_embedding_response = None # type: ignore
self.async_completion_kwargs_fail = None # type: ignore self.async_completion_kwargs_fail = None # type: ignore
self.async_embedding_kwargs_fail = None # type: ignore self.async_embedding_kwargs_fail = None # type: ignore
self.streaming_response_obj = None # type: ignore self.streaming_response_obj = None # type: ignore
blue_color_code = "\033[94m" blue_color_code = "\033[94m"
reset_color_code = "\033[0m" reset_color_code = "\033[0m"
print(f"{blue_color_code}Initialized LiteLLM custom logger") print(f"{blue_color_code}Initialized LiteLLM custom logger")
try: try:
print(f"Logger Initialized with following methods:") print(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print the methods # Pretty print the methods
for method in methods: for method in methods:
print(f" - {method}") print(f" - {method}")
@@ -33,29 +38,32 @@ class testCustomCallbackProxy(CustomLogger):
except: except:
pass pass
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
print(f"Pre-API Call") print(f"Pre-API Call")
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
print(f"Post-API Call") print(f"Post-API Call")
def log_stream_event(self, kwargs, response_obj, start_time, end_time): def log_stream_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Stream") print(f"On Stream")
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Success") print(f"On Success")
self.success = True self.success = True
def log_failure_event(self, kwargs, response_obj, start_time, end_time): def log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Failure") print(f"On Failure")
self.failure = True self.failure = True
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async success") print(f"On Async success")
self.async_success = True self.async_success = True
print("Value of async success: ", self.async_success) print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs) print("\n kwargs: ", kwargs)
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada": if (
kwargs.get("model") == "azure-embedding-model"
or kwargs.get("model") == "ada"
):
print("Got an embedding model", kwargs.get("model")) print("Got an embedding model", kwargs.get("model"))
print("Setting embedding success to True") print("Setting embedding success to True")
self.async_success_embedding = True self.async_success_embedding = True
@@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("stream") == True: if kwargs.get("stream") == True:
self.streaming_response_obj = response_obj self.streaming_response_obj = response_obj
self.async_completion_kwargs = kwargs self.async_completion_kwargs = kwargs
model = kwargs.get("model", None) model = kwargs.get("model", None)
@@ -74,17 +81,18 @@ class testCustomCallbackProxy(CustomLogger):
# Access litellm_params passed to litellm.completion(), example access `metadata` # Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here metadata = litellm_params.get(
"metadata", {}
) # headers passed to LiteLLM proxy, can be found here
# Calculate cost using litellm.completion_cost() # Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj) cost = litellm.completion_cost(completion_response=response_obj)
response = response_obj response = response_obj
# tokens used in response # tokens used in response
usage = response_obj["usage"] usage = response_obj["usage"]
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger)) print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print( print(
f""" f"""
Model: {model}, Model: {model},
@@ -98,8 +106,7 @@ class testCustomCallbackProxy(CustomLogger):
) )
return return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure") print(f"On Async Failure")
self.async_failure = True self.async_failure = True
print("Value of async failure: ", self.async_failure) print("Value of async failure: ", self.async_failure)
@@ -107,7 +114,8 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("model") == "text-embedding-ada-002": if kwargs.get("model") == "text-embedding-ada-002":
self.async_failure_embedding = True self.async_failure_embedding = True
self.async_embedding_kwargs_fail = kwargs self.async_embedding_kwargs_fail = kwargs
self.async_completion_kwargs_fail = kwargs self.async_completion_kwargs_fail = kwargs
my_custom_logger = testCustomCallbackProxy()
my_custom_logger = testCustomCallbackProxy()
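A short sketch of registering the instance created above; assigning to `litellm.callbacks` mirrors what the router tests later in this diff do, while the YAML key and module path in the comment are assumptions to be checked against the proxy docs.
# Sketch only: hook the logger instance above into litellm.
import litellm

litellm.callbacks = [my_custom_logger]  # every completion/embedding call now fires the hooks above

# Proxy-side registration (assumed YAML key and module path - verify against the docs):
#   litellm_settings:
#     callbacks: custom_callbacks.my_custom_logger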
File diff suppressed because it is too large
@@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List from typing import Optional, Literal, List
from litellm import Router, Cache from litellm import Router, Cache
import litellm import litellm
@@ -14,206 +15,274 @@ from litellm.integrations.custom_logger import CustomLogger
## 2: Post-API-Call ## 2: Post-API-Call
## 3: On LiteLLM Call success ## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure ## 4: On LiteLLM Call failure
## fallbacks ## fallbacks
## retries ## retries
# Test cases # Test cases
## 1. Simple Azure OpenAI acompletion + streaming call ## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call ## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries ## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries ## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks ## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks ## 6. Azure OpenAI aembedding call with fallbacks
# Test interfaces # Test interfaces
## 1. router.completion() + router.embeddings() ## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings ## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
""" """
The set of expected inputs to a custom handler for a The set of expected inputs to a custom handler for a
""" """
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
self.errors = [] self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
try: try:
print(f'received kwargs in pre-input: {kwargs}') print(f"received kwargs in pre-input: {kwargs}")
self.states.append("sync_pre_api_call") self.states.append("sync_pre_api_call")
## MODEL ## MODEL
assert isinstance(model, str) assert isinstance(model, str)
## MESSAGES ## MESSAGES
assert isinstance(messages, list) assert isinstance(messages, list)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
### ROUTER-SPECIFIC KWARGS ### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) assert isinstance(
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict) assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time): def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try: try:
self.states.append("post_api_call") self.states.append("post_api_call")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert end_time == None assert end_time == None
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str)) assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
### ROUTER-SPECIFIC KWARGS ### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) assert isinstance(
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict) assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.states.append("async_stream") self.states.append("async_stream")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse) assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
except: ) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.states.append("sync_success") self.states.append("sync_success")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse) assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time): def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.states.append("sync_failure") self.states.append("sync_failure")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
except: ) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs): async def async_log_pre_api_call(self, model, messages, kwargs):
try: try:
""" """
No-op. No-op.
Not implemented yet. Not implemented yet.
""" """
pass pass
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
self.states.append("async_success") self.states.append("async_success")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)) assert isinstance(
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str)) assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
### ROUTER-SPECIFIC KWARGS ### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"], dict)
@@ -221,10 +290,14 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) assert isinstance(
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict) assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -232,257 +305,281 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
try: try:
print(f"received original response: {kwargs['original_response']}") print(f"received original response: {kwargs['original_response']}")
self.states.append("async_failure") self.states.append("async_failure")
## START TIME ## START TIME
assert isinstance(start_time, datetime) assert isinstance(start_time, datetime)
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict)) assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
except: )
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
# Simple Azure OpenAI call
# Simple Azure OpenAI call
## COMPLETION ## COMPLETION
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_azure(): async def test_async_chat_azure():
try: try:
customHandler_completion_azure_router = CompletionCustomHandler() customHandler_completion_azure_router = CompletionCustomHandler()
customHandler_streaming_azure_router = CompletionCustomHandler() customHandler_streaming_azure_router = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_completion_azure_router] litellm.callbacks = [customHandler_completion_azure_router]
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
] ]
router = Router(model_list=model_list) # type: ignore router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo", response = await router.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" )
}])
await asyncio.sleep(2) await asyncio.sleep(2)
assert len(customHandler_completion_azure_router.errors) == 0 assert len(customHandler_completion_azure_router.errors) == 0
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success assert (
# streaming len(customHandler_completion_azure_router.states) == 3
) # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming_azure_router] litellm.callbacks = [customHandler_streaming_azure_router]
router2 = Router(model_list=model_list) # type: ignore router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(model="gpt-3.5-turbo", response = await router2.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" stream=True,
}], )
stream=True) async for chunk in response:
async for chunk in response:
print(f"async azure router chunk: {chunk}") print(f"async azure router chunk: {chunk}")
continue continue
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_streaming_azure_router.states}") print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
assert len(customHandler_streaming_azure_router.errors) == 0 assert len(customHandler_streaming_azure_router.errors) == 0
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success assert (
# failure len(customHandler_streaming_azure_router.states) >= 4
) # pre, post, stream (multiple times), success
# failure
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": "my-bad-key", "api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
] ]
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore router3 = Router(model_list=model_list) # type: ignore
try: try:
response = await router3.acompletion(model="gpt-3.5-turbo", response = await router3.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" )
}])
print(f"response in router3 acompletion: {response}") print(f"response in router3 acompletion: {response}")
except: except:
pass pass
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}") print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0 assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states assert "async_failure" in customHandler_failure.states
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure()) # asyncio.run(test_async_chat_azure())
## EMBEDDING ## EMBEDDING
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_embedding_azure(): async def test_async_embedding_azure():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
model_list = [ model_list = [
{ {
"model_name": "azure-embedding-model", # openai model name "model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model", "model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
] ]
router = Router(model_list=model_list) # type: ignore router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(model="azure-embedding-model", response = await router.aembedding(
input=["hello from litellm!"]) model="azure-embedding-model", input=["hello from litellm!"]
)
await asyncio.sleep(2) await asyncio.sleep(2)
assert len(customHandler.errors) == 0 assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success assert len(customHandler.states) == 3 # pre, post, success
# failure # failure
model_list = [ model_list = [
{ {
"model_name": "azure-embedding-model", # openai model name "model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model", "model": "azure/azure-embedding-model",
"api_key": "my-bad-key", "api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
] ]
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore router3 = Router(model_list=model_list) # type: ignore
try: try:
response = await router3.aembedding(model="azure-embedding-model", response = await router3.aembedding(
input=["hello from litellm!"]) model="azure-embedding-model", input=["hello from litellm!"]
)
print(f"response in router3 aembedding: {response}") print(f"response in router3 aembedding: {response}")
except: except:
pass pass
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}") print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0 assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states assert "async_failure" in customHandler_failure.states
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure()) # asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks # Azure OpenAI call w/ Fallbacks
## COMPLETION ## COMPLETION
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_azure_with_fallbacks(): async def test_async_chat_azure_with_fallbacks():
try: try:
customHandler_fallbacks = CompletionCustomHandler() customHandler_fallbacks = CompletionCustomHandler()
litellm.callbacks = [customHandler_fallbacks] litellm.callbacks = [customHandler_fallbacks]
# with fallbacks # with fallbacks
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": "my-bad-key", "api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
{ {
"model_name": "gpt-3.5-turbo-16k", "model_name": "gpt-3.5-turbo-16k",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo-16k", "model": "gpt-3.5-turbo-16k",
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
} },
] ]
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo", response = await router.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" )
}])
await asyncio.sleep(2) await asyncio.sleep(2)
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}") print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
assert len(customHandler_fallbacks.errors) == 0 assert len(customHandler_fallbacks.errors) == 0
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success assert (
len(customHandler_fallbacks.states) == 6
) # pre, post, failure, pre, post, success
litellm.callbacks = [] litellm.callbacks = []
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure_with_fallbacks()) # asyncio.run(test_async_chat_azure_with_fallbacks())
# CACHING
# CACHING
## Test Azure - completion, embedding ## Test Azure - completion, embedding
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_completion_azure_caching(): async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler() customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching] litellm.callbacks = [customHandler_caching]
unique_time = time.time() unique_time = time.time()
model_list = [ model_list = [
{ {
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
}, },
{ {
"model_name": "gpt-3.5-turbo-16k", "model_name": "gpt-3.5-turbo-16k",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo-16k", "model": "gpt-3.5-turbo-16k",
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
} },
] ]
router = Router(model_list=model_list) # type: ignore router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion(model="gpt-3.5-turbo", response1 = await router.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
"content": f"Hi 👋 - i'm async azure {unique_time}" caching=True,
}], )
caching=True)
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion(model="gpt-3.5-turbo", response2 = await router.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
"content": f"Hi 👋 - i'm async azure {unique_time}" caching=True,
}], )
caching=True) await asyncio.sleep(1) # success callbacks are done in parallel
await asyncio.sleep(1) # success callbacks are done in parallel print(
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}") f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success assert len(customHandler_caching.states) == 4 # pre, post, success, success
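The caching test above expects two success callbacks for the same prompt: one for the live call and one for the cache hit. A minimal sketch of telling the two apart, assuming kwargs["cache_hit"] is populated the way the handler's assertions above imply (it may be None when the flag is not set):
# Sketch only: count cached vs. live responses using the cache_hit kwarg that
# the assertions above already check for.
from litellm.integrations.custom_logger import CustomLogger


class CacheAwareHandler(CustomLogger):
    def __init__(self):
        self.live_calls = 0
        self.cached_calls = 0

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        if kwargs.get("cache_hit") is True:
            self.cached_calls += 1  # served from the cache (Redis in the test above), no provider round-trip
        else:
            self.live_calls += 1  # cache_hit is False or None -> real API call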
@@ -1,12 +1,14 @@
import sys import sys
import os import os
import io, asyncio import io, asyncio
# import logging # import logging
# logging.basicConfig(level=logging.DEBUG) # logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath('../..')) sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion from litellm import completion
import litellm import litellm
litellm.num_retries = 3 litellm.num_retries = 3
import time, random import time, random
@@ -29,11 +31,14 @@ def pre_request():
import re import re
def verify_log_file(log_file_path):
with open(log_file_path, 'r') as log_file:
def verify_log_file(log_file_path):
with open(log_file_path, "r") as log_file:
log_content = log_file.read() log_content = log_file.read()
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content) print(
f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content
)
# Define the pattern to search for in the log file # Define the pattern to search for in the log file
pattern = r"Response from DynamoDB:{.*?}" pattern = r"Response from DynamoDB:{.*?}"
@@ -50,17 +55,21 @@ def verify_log_file(log_file_path):
print(f"Total occurrences of specified response: {len(matches)}") print(f"Total occurrences of specified response: {len(matches)}")
# Count the occurrences of successful responses (status code 200 or 201) # Count the occurrences of successful responses (status code 200 or 201)
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match) success_count = sum(
1
for match in matches
if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
)
# Print the count of successful responses # Print the count of successful responses
print(f"Count of successful responses from DynamoDB: {success_count}") print(f"Count of successful responses from DynamoDB: {success_count}")
assert success_count == 3 # Expect 3 success logs from dynamoDB assert success_count == 3 # Expect 3 success logs from dynamoDB
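To make the check above concrete, here is the same pattern run against a single fabricated log line (the line content is invented purely for illustration):
# Sketch only: the regex and status-code check above applied to an invented log line.
import re

sample_log = "Response from DynamoDB:{'ResponseMetadata': {'HTTPStatusCode': 200}} item stored"
matches = re.findall(r"Response from DynamoDB:{.*?}", sample_log)
print(matches)  # non-greedy match stops at the first closing brace
success_count = sum(
    1
    for match in matches
    if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
)
print(success_count)  # 1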
def test_dynamo_logging(): def test_dynamo_logging():
# all dynamodb requests need to be in one test function # all dynamodb requests need to be in one test function
# since we are modifying stdout, and pytests runs tests in parallel # since we are modifying stdout, and pytests runs tests in parallel
try: try:
# pre # pre
# redirect stdout to log_file # redirect stdout to log_file
@@ -69,44 +78,44 @@ def test_dynamo_logging():
litellm.set_verbose = True litellm.set_verbose = True
original_stdout, log_file, file_name = pre_request() original_stdout, log_file, file_name = pre_request()
print("Testing async dynamoDB logging") print("Testing async dynamoDB logging")
async def _test(): async def _test():
return await litellm.acompletion( return await litellm.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}], messages=[{"role": "user", "content": "This is a test"}],
max_tokens=100, max_tokens=100,
temperature=0.7, temperature=0.7,
user = "ishaan-2" user="ishaan-2",
) )
response = asyncio.run(_test()) response = asyncio.run(_test())
print(f"response: {response}") print(f"response: {response}")
# streaming + async
# streaming + async
async def _test2(): async def _test2():
response = await litellm.acompletion( response = await litellm.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content":"This is a test"}], messages=[{"role": "user", "content": "This is a test"}],
max_tokens=10, max_tokens=10,
temperature=0.7, temperature=0.7,
user = "ishaan-2", user="ishaan-2",
stream=True stream=True,
) )
async for chunk in response: async for chunk in response:
pass pass
asyncio.run(_test2()) asyncio.run(_test2())
# aembedding() # aembedding()
async def _test3(): async def _test3():
return await litellm.aembedding( return await litellm.aembedding(
model="text-embedding-ada-002", model="text-embedding-ada-002", input=["hi"], user="ishaan-2"
input = ["hi"],
user = "ishaan-2"
) )
response = asyncio.run(_test3()) response = asyncio.run(_test3())
time.sleep(1) time.sleep(1)
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred - {e}") pytest.fail(f"An exception occurred - {e}")
finally: finally:
# post, close log file and verify # post, close log file and verify
@@ -117,4 +126,5 @@ def test_dynamo_logging():
verify_log_file(file_name) verify_log_file(file_name)
print("Passed! Testing async dynamoDB logging") print("Passed! Testing async dynamoDB logging")
# test_dynamo_logging_async() # test_dynamo_logging_async()
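The hunks above only show the calls and the log verification; a hedged sketch of how DynamoDB logging is typically switched on beforehand is below. The "dynamodb" callback name and the table-name attribute are assumptions from memory of the docs, not taken from this commit - verify before relying on them.
# Sketch only (assumed API): enable DynamoDB success logging before the calls above.
import litellm

litellm.success_callback = ["dynamodb"]       # assumed callback identifier
litellm.dynamodb_table_name = "litellm-logs"  # hypothetical table name
# AWS credentials/region are expected in the usual AWS_* environment variables.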
@@ -14,39 +14,49 @@ from litellm import embedding, completion
litellm.set_verbose = False litellm.set_verbose = False
def test_openai_embedding(): def test_openai_embedding():
try: try:
litellm.set_verbose=True litellm.set_verbose = True
response = embedding( response = embedding(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"], input=["good morning from litellm", "this is another item"],
metadata = {"anything": "good day"} metadata={"anything": "good day"},
) )
litellm_response = dict(response) litellm_response = dict(response)
litellm_response_keys = set(litellm_response.keys()) litellm_response_keys = set(litellm_response.keys())
litellm_response_keys.discard('_response_ms') litellm_response_keys.discard("_response_ms")
print(litellm_response_keys) print(litellm_response_keys)
print("LiteLLM Response\n") print("LiteLLM Response\n")
# print(litellm_response) # print(litellm_response)
# same request with OpenAI 1.0+ # same request with OpenAI 1.0+
import openai import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'])
client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
response = client.embeddings.create( response = client.embeddings.create(
model="text-embedding-ada-002", input=["good morning from litellm", "this is another item"] model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"],
) )
response = dict(response) response = dict(response)
openai_response_keys = set(response.keys()) openai_response_keys = set(response.keys())
print(openai_response_keys) print(openai_response_keys)
assert litellm_response_keys == openai_response_keys # ENSURE the Keys in litellm response is exactly what the openai package returns assert (
assert len(litellm_response["data"]) == 2 # expect two embedding responses from litellm_response since input had two litellm_response_keys == openai_response_keys
) # ENSURE the Keys in litellm response is exactly what the openai package returns
assert (
len(litellm_response["data"]) == 2
) # expect two embedding responses from litellm_response since input had two
print(openai_response_keys) print(openai_response_keys)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_openai_embedding() # test_openai_embedding()
def test_openai_azure_embedding_simple(): def test_openai_azure_embedding_simple():
try: try:
response = embedding( response = embedding(
@@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple():
) )
print(response) print(response)
response_keys = set(dict(response).keys()) response_keys = set(dict(response).keys())
response_keys.discard('_response_ms') response_keys.discard("_response_ms")
assert set(["usage", "model", "object", "data"]) == set(response_keys) #assert litellm response has expected keys from OpenAI embedding response assert set(["usage", "model", "object", "data"]) == set(
response_keys
) # assert litellm response has expected keys from OpenAI embedding response
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding_simple() # test_openai_azure_embedding_simple()
@@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts():
response = embedding( response = embedding(
model="azure/azure-embedding-model", model="azure/azure-embedding-model",
input=["good morning from litellm"], input=["good morning from litellm"],
timeout=0.00001 timeout=0.00001,
) )
print(response) print(response)
except openai.APITimeoutError: except openai.APITimeoutError:
print("Good job got timeout error!") print("Good job got timeout error!")
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}") pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_azure_embedding_timeouts() # test_openai_azure_embedding_timeouts()
def test_openai_embedding_timeouts(): def test_openai_embedding_timeouts():
try: try:
response = embedding( response = embedding(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=["good morning from litellm"], input=["good morning from litellm"],
timeout=0.00001 timeout=0.00001,
) )
print(response) print(response)
except openai.APITimeoutError: except openai.APITimeoutError:
print("Good job got OpenAI timeout error!") print("Good job got OpenAI timeout error!")
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"Expected timeout error, did not get the correct error. Instead got {e}") pytest.fail(
f"Expected timeout error, did not get the correct error. Instead got {e}"
)
# test_openai_embedding_timeouts() # test_openai_embedding_timeouts()
def test_openai_azure_embedding(): def test_openai_azure_embedding():
try: try:
api_key = os.environ['AZURE_API_KEY'] api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ['AZURE_API_BASE'] api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ['AZURE_API_VERSION'] api_version = os.environ["AZURE_API_VERSION"]
os.environ['AZURE_API_VERSION'] = "" os.environ["AZURE_API_VERSION"] = ""
os.environ['AZURE_API_BASE'] = "" os.environ["AZURE_API_BASE"] = ""
os.environ['AZURE_API_KEY'] = "" os.environ["AZURE_API_KEY"] = ""
response = embedding( response = embedding(
model="azure/azure-embedding-model", model="azure/azure-embedding-model",
@@ -114,137 +136,179 @@ def test_openai_azure_embedding():
) )
print(response) print(response)
os.environ["AZURE_API_VERSION"] = api_version
os.environ['AZURE_API_VERSION'] = api_version os.environ["AZURE_API_BASE"] = api_base
os.environ['AZURE_API_BASE'] = api_base os.environ["AZURE_API_KEY"] = api_key
os.environ['AZURE_API_KEY'] = api_key
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_openai_azure_embedding() # test_openai_azure_embedding()
# test_openai_embedding() # test_openai_embedding()
def test_cohere_embedding(): def test_cohere_embedding():
try: try:
# litellm.set_verbose=True # litellm.set_verbose=True
response = embedding( response = embedding(
model="embed-english-v2.0", input=["good morning from litellm", "this is another item"] model="embed-english-v2.0",
input=["good morning from litellm", "this is another item"],
) )
print(f"response:", response) print(f"response:", response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding() # test_cohere_embedding()
def test_cohere_embedding3(): def test_cohere_embedding3():
try: try:
litellm.set_verbose=True litellm.set_verbose = True
response = embedding( response = embedding(
model="embed-english-v3.0", model="embed-english-v3.0",
input=["good morning from litellm", "this is another item"], input=["good morning from litellm", "this is another item"],
) )
print(f"response:", response) print(f"response:", response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_cohere_embedding3() # test_cohere_embedding3()
def test_bedrock_embedding_titan(): def test_bedrock_embedding_titan():
try: try:
litellm.set_verbose=True litellm.set_verbose = True
response = embedding( response = embedding(
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data", model="amazon.titan-embed-text-v1",
"lets test a second string for good measure"] input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
) )
print(f"response:", response) print(f"response:", response)
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list" assert isinstance(
print(f"type of first embedding:", type(response['data'][0]['embedding'][0])) response["data"][0]["embedding"], list
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" ), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
test_bedrock_embedding_titan() test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere(): def test_bedrock_embedding_cohere():
try: try:
litellm.set_verbose=False litellm.set_verbose = False
response = embedding( response = embedding(
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"], model="cohere.embed-multilingual-v3",
aws_region_name="os.environ/AWS_REGION_NAME_2" input=[
"good morning from litellm, attempting to embed data",
"lets test a second string for good measure",
],
aws_region_name="os.environ/AWS_REGION_NAME_2",
) )
assert isinstance(response['data'][0]['embedding'], list), "Expected response to be a list" assert isinstance(
print(f"type of first embedding:", type(response['data'][0]['embedding'][0])) response["data"][0]["embedding"], list
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" ), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
# print(f"response:", response) # print(f"response:", response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_cohere() # test_bedrock_embedding_cohere()
# comment out hf tests - since hf endpoints are unstable # comment out hf tests - since hf endpoints are unstable
def test_hf_embedding(): def test_hf_embedding():
try: try:
# huggingface/microsoft/codebert-base # huggingface/microsoft/codebert-base
# huggingface/facebook/bart-large # huggingface/facebook/bart-large
response = embedding( response = embedding(
model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"] model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
input=["good morning from litellm", "this is another item"],
) )
print(f"response:", response) print(f"response:", response)
except Exception as e: except Exception as e:
# Note: Huggingface inference API is unstable and fails with "model loading errors all the time" # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
pass pass
# test_hf_embedding() # test_hf_embedding()
# test async embeddings # test async embeddings
def test_aembedding(): def test_aembedding():
try: try:
import asyncio import asyncio
async def embedding_call(): async def embedding_call():
try: try:
response = await litellm.aembedding( response = await litellm.aembedding(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=["good morning from litellm", "this is another item"] input=["good morning from litellm", "this is another item"],
) )
print(response) print(response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call()) asyncio.run(embedding_call())
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_aembedding() # test_aembedding()
def test_aembedding_azure(): def test_aembedding_azure():
try: try:
import asyncio import asyncio
async def embedding_call(): async def embedding_call():
try: try:
response = await litellm.aembedding( response = await litellm.aembedding(
model="azure/azure-embedding-model", model="azure/azure-embedding-model",
input=["good morning from litellm", "this is another item"] input=["good morning from litellm", "this is another item"],
) )
print(response) print(response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
asyncio.run(embedding_call()) asyncio.run(embedding_call())
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_aembedding_azure() # test_aembedding_azure()
def test_sagemaker_embeddings():
try: def test_sagemaker_embeddings():
response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"]) try:
response = litellm.embedding(
model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
input=["good morning from litellm", "this is another item"],
)
print(f"response: {response}") print(f"response: {response}")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_sagemaker_embeddings() # test_sagemaker_embeddings()
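The bedrock embedding tests above both end with the same two shape assertions on the return value. A minimal standalone sketch of that pattern, assuming a valid OPENAI_API_KEY is set (the model name is only one of the examples used in these tests):

import litellm

# Sketch of the shape checks used throughout these embedding tests.
# Assumes OPENAI_API_KEY is set; the model name is only an example.
response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["good morning from litellm", "this is another item"],
)
vector = response["data"][0]["embedding"]
assert isinstance(vector, list), "Expected the embedding to be a list"
assert all(isinstance(x, float) for x in vector), "Expected a list of floats"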
# def local_proxy_embeddings(): # def local_proxy_embeddings():
# litellm.set_verbose=True # litellm.set_verbose=True
# response = embedding( # response = embedding(
# model="openai/custom_embedding", # model="openai/custom_embedding",
# input=["good morning from litellm"], # input=["good morning from litellm"],
# api_base="http://0.0.0.0:8000/" # api_base="http://0.0.0.0:8000/"
# ) # )
View file

@@ -11,17 +11,18 @@ import litellm
from litellm import ( from litellm import (
embedding, embedding,
completion, completion,
# AuthenticationError, # AuthenticationError,
ContextWindowExceededError, ContextWindowExceededError,
# RateLimitError, # RateLimitError,
# ServiceUnavailableError, # ServiceUnavailableError,
# OpenAIError, # OpenAIError,
) )
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
litellm.vertex_project = "pathrise-convert-1606954137718" litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1" litellm.vertex_location = "us-central1"
litellm.num_retries=0 litellm.num_retries = 0
# litellm.failure_callback = ["sentry"] # litellm.failure_callback = ["sentry"]
#### What this tests #### #### What this tests ####
@@ -36,7 +37,8 @@ litellm.num_retries=0
models = ["command-nightly"] models = ["command-nightly"]
# Test 1: Context Window Errors
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_window(model): def test_context_window(model):
print("Testing context window error") print("Testing context window error")
@@ -52,17 +54,27 @@ def test_context_window(model):
print(f"Worked!") print(f"Worked!")
except RateLimitError: except RateLimitError:
print("RateLimited!") print("RateLimited!")
except Exception as e: except Exception as e:
print(f"{e}") print(f"{e}")
pytest.fail(f"An error occcurred - {e}") pytest.fail(f"An error occcurred - {e}")
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model): def test_context_window_with_fallbacks(model):
ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k", "azure/chatgpt-v-2": "gpt-3.5-turbo-16k"} ctx_window_fallback_dict = {
"command-nightly": "claude-2",
"gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
"azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
}
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict) completion(
model=model,
messages=messages,
context_window_fallback_dict=ctx_window_fallback_dict,
)
# for model in litellm.models_by_provider["bedrock"]: # for model in litellm.models_by_provider["bedrock"]:
# test_context_window(model=model) # test_context_window(model=model)
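test_context_window_with_fallbacks above exercises completion's context_window_fallback_dict argument. Stripped down, the pattern looks roughly like this (assumes the relevant provider keys, e.g. COHERE_API_KEY and ANTHROPIC_API_KEY, are set):

from litellm import completion

# Rough sketch of the context-window fallback pattern used above:
# when the prompt overflows the primary model's context window,
# litellm retries with the model mapped in the fallback dict.
ctx_window_fallback_dict = {"command-nightly": "claude-2"}
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
completion(
    model="command-nightly",
    messages=messages,
    context_window_fallback_dict=ctx_window_fallback_dict,
)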
@@ -98,7 +110,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["AI21_API_KEY"] = "bad-key" os.environ["AI21_API_KEY"] = "bad-key"
elif "togethercomputer" in model: elif "togethercomputer" in model:
temporary_key = os.environ["TOGETHERAI_API_KEY"] temporary_key = os.environ["TOGETHERAI_API_KEY"]
os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a" os.environ[
"TOGETHERAI_API_KEY"
] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
elif model in litellm.openrouter_models: elif model in litellm.openrouter_models:
temporary_key = os.environ["OPENROUTER_API_KEY"] temporary_key = os.environ["OPENROUTER_API_KEY"]
os.environ["OPENROUTER_API_KEY"] = "bad-key" os.environ["OPENROUTER_API_KEY"] = "bad-key"
@@ -115,9 +129,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
temporary_key = os.environ["REPLICATE_API_KEY"] temporary_key = os.environ["REPLICATE_API_KEY"]
os.environ["REPLICATE_API_KEY"] = "bad-key" os.environ["REPLICATE_API_KEY"] = "bad-key"
print(f"model: {model}") print(f"model: {model}")
response = completion( response = completion(model=model, messages=messages)
model=model, messages=messages
)
print(f"response: {response}") print(f"response: {response}")
except AuthenticationError as e: except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {str(e)}") print(f"AuthenticationError Caught Exception - {str(e)}")
@@ -148,23 +160,25 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["REPLICATE_API_KEY"] = temporary_key os.environ["REPLICATE_API_KEY"] = temporary_key
elif "j2" in model: elif "j2" in model:
os.environ["AI21_API_KEY"] = temporary_key os.environ["AI21_API_KEY"] = temporary_key
elif ("togethercomputer" in model): elif "togethercomputer" in model:
os.environ["TOGETHERAI_API_KEY"] = temporary_key os.environ["TOGETHERAI_API_KEY"] = temporary_key
elif model in litellm.aleph_alpha_models: elif model in litellm.aleph_alpha_models:
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
elif model in litellm.nlp_cloud_models: elif model in litellm.nlp_cloud_models:
os.environ["NLP_CLOUD_API_KEY"] = temporary_key os.environ["NLP_CLOUD_API_KEY"] = temporary_key
elif "bedrock" in model: elif "bedrock" in model:
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return return
# for model in litellm.models_by_provider["bedrock"]: # for model in litellm.models_by_provider["bedrock"]:
# invalid_auth(model=model) # invalid_auth(model=model)
# invalid_auth(model="command-nightly") # invalid_auth(model="command-nightly")
# Test 3: Invalid Request Error
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_invalid_request_error(model): def test_invalid_request_error(model):
messages = [{"content": "hey, how's it going?", "role": "user"}] messages = [{"content": "hey, how's it going?", "role": "user"}]
@@ -173,23 +187,18 @@ def test_invalid_request_error(model):
completion(model=model, messages=messages, max_tokens="hello world") completion(model=model, messages=messages, max_tokens="hello world")
def test_completion_azure_exception(): def test_completion_azure_exception():
try: try:
import openai import openai
print("azure gpt-3.5 test\n\n") print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"] old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning" os.environ["AZURE_API_KEY"] = "good morning"
response = completion( response = completion(
model="azure/chatgpt-v-2", model="azure/chatgpt-v-2",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
os.environ["AZURE_API_KEY"] = old_azure_key os.environ["AZURE_API_KEY"] = old_azure_key
print(f"response: {response}") print(f"response: {response}")
@@ -199,25 +208,24 @@ def test_completion_azure_exception():
print("good job got the correct error for azure when key not set") print("good job got the correct error for azure when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_azure_exception() # test_completion_azure_exception()
async def asynctest_completion_azure_exception(): async def asynctest_completion_azure_exception():
try: try:
import openai import openai
import litellm import litellm
print("azure gpt-3.5 test\n\n") print("azure gpt-3.5 test\n\n")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
old_azure_key = os.environ["AZURE_API_KEY"] old_azure_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "good morning" os.environ["AZURE_API_KEY"] = "good morning"
response = await litellm.acompletion( response = await litellm.acompletion(
model="azure/chatgpt-v-2", model="azure/chatgpt-v-2",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
print(f"response: {response}") print(f"response: {response}")
print(response) print(response)
@@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception():
print("Got wrong exception") print("Got wrong exception")
print("exception", e) print("exception", e)
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# import asyncio # import asyncio
# asyncio.run( # asyncio.run(
# asynctest_completion_azure_exception() # asynctest_completion_azure_exception()
@@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model():
try: try:
import openai import openai
import litellm, asyncio import litellm, asyncio
print("azure exception bad model\n\n") print("azure exception bad model\n\n")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
async def test(): async def test():
response = await litellm.acompletion( response = await litellm.acompletion(
model="openai/gpt-6", model="openai/gpt-6",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
asyncio.run(test()) asyncio.run(test())
except openai.NotFoundError: except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!") print("Good job this is a NotFoundError for a model that does not exist!")
@@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model():
assert isinstance(e, openai.BadRequestError) assert isinstance(e, openai.BadRequestError)
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# asynctest_completion_openai_exception_bad_model()
# asynctest_completion_openai_exception_bad_model()
def asynctest_completion_azure_exception_bad_model(): def asynctest_completion_azure_exception_bad_model():
try: try:
import openai import openai
import litellm, asyncio import litellm, asyncio
print("azure exception bad model\n\n") print("azure exception bad model\n\n")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
async def test(): async def test():
response = await litellm.acompletion( response = await litellm.acompletion(
model="azure/gpt-12", model="azure/gpt-12",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
asyncio.run(test()) asyncio.run(test())
except openai.NotFoundError: except openai.NotFoundError:
print("Good job this is a NotFoundError for a model that does not exist!") print("Good job this is a NotFoundError for a model that does not exist!")
@@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model():
print("Raised wrong type of exception", type(e)) print("Raised wrong type of exception", type(e))
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# asynctest_completion_azure_exception_bad_model() # asynctest_completion_azure_exception_bad_model()
def test_completion_openai_exception(): def test_completion_openai_exception():
# test if openai:gpt raises openai.AuthenticationError # test if openai:gpt raises openai.AuthenticationError
try: try:
import openai import openai
print("openai gpt-3.5 test\n\n") print("openai gpt-3.5 test\n\n")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
old_azure_key = os.environ["OPENAI_API_KEY"] old_azure_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "good morning" os.environ["OPENAI_API_KEY"] = "good morning"
response = completion( response = completion(
model="gpt-4", model="gpt-4",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
print(f"response: {response}") print(f"response: {response}")
print(response) print(response)
@@ -317,25 +321,24 @@ def test_completion_openai_exception():
print("OpenAI: good job got the correct error for openai when key not set") print("OpenAI: good job got the correct error for openai when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception() # test_completion_openai_exception()
def test_completion_mistral_exception(): def test_completion_mistral_exception():
# test if mistral/mistral-tiny raises openai.AuthenticationError # test if mistral/mistral-tiny raises openai.AuthenticationError
try: try:
import openai import openai
print("Testing mistral ai exception mapping") print("Testing mistral ai exception mapping")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
old_azure_key = os.environ["MISTRAL_API_KEY"] old_azure_key = os.environ["MISTRAL_API_KEY"]
os.environ["MISTRAL_API_KEY"] = "good morning" os.environ["MISTRAL_API_KEY"] = "good morning"
response = completion( response = completion(
model="mistral/mistral-tiny", model="mistral/mistral-tiny",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
print(f"response: {response}") print(f"response: {response}")
print(response) print(response)
@@ -344,11 +347,11 @@ def test_completion_mistral_exception():
print("good job got the correct error for openai when key not set") print("good job got the correct error for openai when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception() # test_completion_mistral_exception()
# # test_invalid_request_error(model="command-nightly") # # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors # # Test 3: Rate Limit Errors
# def test_model_call(model): # def test_model_call(model):
@@ -387,4 +390,4 @@ def test_completion_mistral_exception():
# counts[result] += 1 # counts[result] += 1
# accuracy_score = counts[True]/(counts[True] + counts[False]) # accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}") # print(f"accuracy_score: {accuracy_score}")
Some files were not shown because too many files have changed in this diff.