refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions

@@ -1,9 +1,13 @@
 repos:
+- repo: https://github.com/psf/black
+  rev: stable
+  hooks:
+    - id: black
 - repo: https://github.com/pycqa/flake8
   rev: 3.8.4 # The version of flake8 to use
   hooks:
     - id: flake8
-      exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
+      exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/|^litellm/proxy/tests/
       additional_dependencies: [flake8-print]
       files: litellm/.*\.py
 - repo: local
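
Every hunk below is a mechanical rewrite produced by Black (quote normalization, line wrapping, trailing commas). As a quick, purely illustrative sketch of that behavior (not part of the commit; assumes `black` is installed), Black's documented Python API can be used to preview the normalization:

# Illustrative sketch only - not from this commit. Assumes `pip install black`.
import black

before = "os.environ['OPENAI_API_KEY'] = ''\nmodels = ['gpt-3.5-turbo', 'claude-2']\n"

# black.format_str() applies the same kind of rewrites the hunks below show:
# single quotes become double quotes, spacing and wrapping are normalized.
after = black.format_str(before, mode=black.Mode())
print(after)
# os.environ["OPENAI_API_KEY"] = ""
# models = ["gpt-3.5-turbo", "claude-2"]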

@@ -9,33 +9,37 @@ import os
 # Define the list of models to benchmark
 # select any LLM listed here: https://docs.litellm.ai/docs/providers
-models = ['gpt-3.5-turbo', 'claude-2']
+models = ["gpt-3.5-turbo", "claude-2"]
 # Enter LLM API keys
 # https://docs.litellm.ai/docs/providers
-os.environ['OPENAI_API_KEY'] = ""
-os.environ['ANTHROPIC_API_KEY'] = ""
+os.environ["OPENAI_API_KEY"] = ""
+os.environ["ANTHROPIC_API_KEY"] = ""
 # List of questions to benchmark (replace with your questions)
-questions = [
-    "When will BerriAI IPO?",
-    "When will LiteLLM hit $100M ARR?"
-]
+questions = ["When will BerriAI IPO?", "When will LiteLLM hit $100M ARR?"]
 # Enter your system prompt here
 system_prompt = """
 You are LiteLLMs helpful assistant
 """
 @click.command()
-@click.option('--system-prompt', default="You are a helpful assistant that can answer questions.", help="System prompt for the conversation.")
+@click.option(
+    "--system-prompt",
+    default="You are a helpful assistant that can answer questions.",
+    help="System prompt for the conversation.",
+)
 def main(system_prompt):
     for question in questions:
         data = []  # Data for the current question
         with tqdm(total=len(models)) as pbar:
             for model in models:
-                colored_description = colored(f"Running question: {question} for model: {model}", 'green')
+                colored_description = colored(
+                    f"Running question: {question} for model: {model}", "green"
+                )
                 pbar.set_description(colored_description)
                 start_time = time.time()
@@ -44,35 +48,43 @@ def main(system_prompt):
                     max_tokens=500,
                     messages=[
                         {"role": "system", "content": system_prompt},
-                        {"role": "user", "content": question}
+                        {"role": "user", "content": question},
                     ],
                 )
                 end = time.time()
                 total_time = end - start_time
                 cost = completion_cost(completion_response=response)
-                raw_response = response['choices'][0]['message']['content']
-                data.append({
-                    'Model': colored(model, 'light_blue'),
-                    'Response': raw_response, # Colorize the response
-                    'ResponseTime': colored(f"{total_time:.2f} seconds", "red"),
-                    'Cost': colored(f"${cost:.6f}", 'green'), # Colorize the cost
-                })
+                raw_response = response["choices"][0]["message"]["content"]
+                data.append(
+                    {
+                        "Model": colored(model, "light_blue"),
+                        "Response": raw_response,  # Colorize the response
+                        "ResponseTime": colored(f"{total_time:.2f} seconds", "red"),
+                        "Cost": colored(f"${cost:.6f}", "green"),  # Colorize the cost
+                    }
+                )
                 pbar.update(1)
         # Separate headers from the data
-        headers = ['Model', 'Response', 'Response Time (seconds)', 'Cost ($)']
+        headers = ["Model", "Response", "Response Time (seconds)", "Cost ($)"]
         colwidths = [15, 80, 15, 10]
         # Create a nicely formatted table for the current question
-        table = tabulate([list(d.values()) for d in data], headers, tablefmt="grid", maxcolwidths=colwidths)
+        table = tabulate(
+            [list(d.values()) for d in data],
+            headers,
+            tablefmt="grid",
+            maxcolwidths=colwidths,
+        )
         # Print the table for the current question
-        colored_question = colored(question, 'green')
+        colored_question = colored(question, "green")
         click.echo(f"\nBenchmark Results for '{colored_question}':")
         click.echo(table)  # Display the formatted table
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

@@ -1,25 +1,22 @@
 import sys, os
 import traceback
 from dotenv import load_dotenv
 load_dotenv()
 import litellm
 from litellm import embedding, completion, completion_cost
 from autoevals.llm import *
 ###################
 import litellm
 # litellm completion call
 question = "which country has the highest population"
 response = litellm.completion(
-    model = "gpt-3.5-turbo",
-    messages = [
-        {
-            "role": "user",
-            "content": question
-        }
-    ],
+    model="gpt-3.5-turbo",
+    messages=[{"role": "user", "content": question}],
 )
 print(response)
 # use the auto eval Factuality() evaluator
@@ -27,9 +24,11 @@ print(response)
 print("calling evaluator")
 evaluator = Factuality()
 result = evaluator(
-    output=response.choices[0]["message"]["content"], # response from litellm.completion()
+    output=response.choices[0]["message"][
+        "content"
+    ],  # response from litellm.completion()
     expected="India",  # expected output
-    input=question # question passed to litellm.completion
+    input=question,  # question passed to litellm.completion
 )
 print(result)

@@ -7,6 +7,7 @@ from util import handle_error
 from litellm import completion
 import os, dotenv, time
 import json
 dotenv.load_dotenv()
 # TODO: set your keys in .env or here:
@@ -19,47 +20,61 @@ verbose = True
 # litellm.caching_with_models = True # CACHING: caching_with_models Keys in the cache are messages + model. - to learn more: https://docs.litellm.ai/docs/caching/
 ######### PROMPT LOGGING ##########
-os.environ["PROMPTLAYER_API_KEY"] = "" # set your promptlayer key here - https://promptlayer.com/
+os.environ[
+    "PROMPTLAYER_API_KEY"
+] = ""  # set your promptlayer key here - https://promptlayer.com/
 # set callbacks
 litellm.success_callback = ["promptlayer"]
 ############ HELPER FUNCTIONS ###################################
 def print_verbose(print_statement):
     if verbose:
         print(print_statement)
 app = Flask(__name__)
 CORS(app)
-@app.route('/')
+@app.route("/")
 def index():
-    return 'received!', 200
+    return "received!", 200
 def data_generator(response):
     for chunk in response:
         yield f"data: {json.dumps(chunk)}\n\n"
-@app.route('/chat/completions', methods=["POST"])
+@app.route("/chat/completions", methods=["POST"])
 def api_completion():
     data = request.json
     start_time = time.time()
-    if data.get('stream') == "True":
-        data['stream'] = True # convert to boolean
+    if data.get("stream") == "True":
+        data["stream"] = True  # convert to boolean
     try:
         if "prompt" not in data:
             raise ValueError("data needs to have prompt")
-        data["model"] = "togethercomputer/CodeLlama-34b-Instruct" # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
+        data[
+            "model"
+        ] = "togethercomputer/CodeLlama-34b-Instruct"  # by default use Together AI's CodeLlama model - https://api.together.xyz/playground/chat?model=togethercomputer%2FCodeLlama-34b-Instruct
         # COMPLETION CALL
         system_prompt = "Only respond to questions about code. Say 'I don't know' to anything outside of that."
-        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": data.pop("prompt")}]
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": data.pop("prompt")},
+        ]
         data["messages"] = messages
         print(f"data: {data}")
        
 response = completion(**data)
         ## LOG SUCCESS
         end_time = time.time()
-        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return Response(data_generator(response), mimetype='text/event-stream')
+        if (
+            "stream" in data and data["stream"] == True
+        ):  # use generate_responses to stream responses
+            return Response(data_generator(response), mimetype="text/event-stream")
     except Exception as e:
         # call handle_error function
         print_verbose(f"Got Error api_completion(): {traceback.format_exc()}")
@@ -69,7 +84,8 @@ def api_completion():
         return handle_error(data=data)
     return response
-@app.route('/get_models', methods=["POST"])
+@app.route("/get_models", methods=["POST"])
 def get_models():
     try:
         return litellm.model_list
@@ -78,7 +94,8 @@ def get_models():
         response = {"error": str(e)}
         return response, 200
 if __name__ == "__main__":
     from waitress import serve
     serve(app, host="0.0.0.0", port=4000, threads=500)

@@ -8,17 +8,18 @@ def get_next_url(response):
     :param response: response from requests
     :return: next url or None
     """
-    if 'link' not in response.headers:
+    if "link" not in response.headers:
         return None
     headers = response.headers
-    next_url = headers['Link']
+    next_url = headers["Link"]
     print(next_url)
     start_index = next_url.find("<")
     end_index = next_url.find(">")
     return next_url[1:end_index]
 def get_models(url):
     """
     Function to retrieve all models from paginated endpoint
@@ -36,6 +37,7 @@ def get_models(url):
         models.extend(payload)
     return models
 def get_cleaned_models(models):
     """
     Function to clean retrieved models
@@ -47,8 +49,9 @@ def get_cleaned_models(models):
         cleaned_models.append(model["id"])
     return cleaned_models
 # Get text-generation models
-url = 'https://huggingface.co/api/models?filter=text-generation-inference'
+url = "https://huggingface.co/api/models?filter=text-generation-inference"
 text_generation_models = get_models(url)
 cleaned_text_generation_models = get_cleaned_models(text_generation_models)
@@ -56,7 +59,7 @@ print(cleaned_text_generation_models)
 # Get conversational models
-url = 'https://huggingface.co/api/models?filter=conversational'
+url = "https://huggingface.co/api/models?filter=conversational"
 conversational_models = get_models(url)
 cleaned_conversational_models = get_cleaned_models(conversational_models)
@@ -69,15 +72,19 @@ def write_to_txt(cleaned_models, filename):
     :param cleaned_models: list of cleaned models
     :param filename: name of the text file
     """
-    with open(filename, 'w') as f:
+    with open(filename, "w") as f:
         for item in cleaned_models:
             f.write("%s\n" % item)
 # Write contents of cleaned_text_generation_models to text_generation_models.txt
-write_to_txt(cleaned_text_generation_models, 'huggingface_llms_metadata/hf_text_generation_models.txt')
+write_to_txt(
+    cleaned_text_generation_models,
+    "huggingface_llms_metadata/hf_text_generation_models.txt",
+)
 # Write contents of cleaned_conversational_models to conversational_models.txt
-write_to_txt(cleaned_conversational_models, 'huggingface_llms_metadata/hf_conversational_models.txt')
+write_to_txt(
+    cleaned_conversational_models,
+    "huggingface_llms_metadata/hf_conversational_models.txt",
+)

@@ -1,4 +1,3 @@
 import openai
 api_base = f"http://0.0.0.0:8000"
@@ -8,29 +7,29 @@ openai.api_key = "temp-key"
 print(openai.api_base)
-print(f'LiteLLM: response from proxy with streaming')
+print(f"LiteLLM: response from proxy with streaming")
 response = openai.ChatCompletion.create(
     model="ollama/llama2",
-    messages = [
+    messages=[
         {
             "role": "user",
-            "content": "this is a test request, acknowledge that you got it"
+            "content": "this is a test request, acknowledge that you got it",
         }
     ],
-    stream=True
+    stream=True,
 )
 for chunk in response:
-    print(f'LiteLLM: streaming response from proxy {chunk}')
+    print(f"LiteLLM: streaming response from proxy {chunk}")
 response = openai.ChatCompletion.create(
     model="ollama/llama2",
-    messages = [
+    messages=[
         {
             "role": "user",
-            "content": "this is a test request, acknowledge that you got it"
+            "content": "this is a test request, acknowledge that you got it",
         }
-    ]
+    ],
 )
-print(f'LiteLLM: response from proxy {response}')
+print(f"LiteLLM: response from proxy {response}")

@@ -12,42 +12,51 @@ import pytest
 from litellm import Router
 import litellm
-litellm.set_verbose=False
+litellm.set_verbose = False
 os.environ.pop("AZURE_AD_TOKEN")
-model_list = [{ # list of model deployments
-    "model_name": "gpt-3.5-turbo", # model alias
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-v-2", # actual model name
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-functioncalling",
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "gpt-3.5-turbo",
-        "api_key": os.getenv("OPENAI_API_KEY"),
-    }
-}]
+model_list = [
+    {  # list of model deployments
+        "model_name": "gpt-3.5-turbo",  # model alias
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",  # actual model name
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    },
+]
 router = Router(model_list=model_list)
-file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
+file_paths = [
+    "test_questions/question1.txt",
+    "test_questions/question2.txt",
+    "test_questions/question3.txt",
+]
 questions = []
 for file_path in file_paths:
     try:
         print(file_path)
-        with open(file_path, 'r') as file:
+        with open(file_path, "r") as file:
             content = file.read()
             questions.append(content)
     except FileNotFoundError as e:
@@ -59,7 +68,6 @@ for file_path in file_paths:
 # print(q)
 # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
 # Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
 # show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
@@ -74,10 +82,18 @@ def make_openai_completion(question):
     try:
         start_time = time.time()
         import openai
-        client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
+        client = openai.OpenAI(
+            api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
+        )  # base_url="http://0.0.0.0:8000",
         response = client.chat.completions.create(
             model="gpt-3.5-turbo",
-            messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
+            messages=[
+                {
+                    "role": "system",
+                    "content": f"You are a helpful assistant. Answer this question{question}",
+                }
+            ],
         )
         print(response)
         end_time = time.time()
@@ -92,11 +108,10 @@ def make_openai_completion(question):
     except Exception as e:
         # Log exceptions for failed calls
         with open("error_log.txt", "a") as error_log_file:
-            error_log_file.write(
-                f"Question: {question[:100]}\nException: {str(e)}\n\n"
-            )
+            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
         return None
 # Number of concurrent calls (you can adjust this)
 concurrent_calls = 100
@@ -133,4 +148,3 @@ with open("request_log.txt", "r") as log_file:
 with open("error_log.txt", "r") as error_log_file:
     print("\nError Log:\n", error_log_file.read())

@@ -12,42 +12,51 @@ import pytest
 from litellm import Router
 import litellm
-litellm.set_verbose=False
+litellm.set_verbose = False
 # os.environ.pop("AZURE_AD_TOKEN")
-model_list = [{ # list of model deployments
-    "model_name": "gpt-3.5-turbo", # model alias
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-v-2", # actual model name
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-functioncalling",
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "gpt-3.5-turbo",
-        "api_key": os.getenv("OPENAI_API_KEY"),
-    }
-}]
+model_list = [
+    {  # list of model deployments
+        "model_name": "gpt-3.5-turbo",  # model alias
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",  # actual model name
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    },
+]
 router = Router(model_list=model_list)
-file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
+file_paths = [
+    "test_questions/question1.txt",
+    "test_questions/question2.txt",
+    "test_questions/question3.txt",
+]
 questions = []
 for file_path in file_paths:
     try:
         print(file_path)
-        with open(file_path, 'r') as file:
+        with open(file_path, "r") as file:
             content = file.read()
             questions.append(content)
     except FileNotFoundError as e:
@@ -59,7 +68,6 @@ for file_path in file_paths:
 # print(q)
 # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
 # Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
 # show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
@@ -76,9 +84,12 @@ def make_openai_completion(question):
         import requests
         data = {
-            'model': 'gpt-3.5-turbo',
-            'messages': [
-                {'role': 'system', 'content': f'You are a helpful assistant. Answer this question{question}'},
+            "model": "gpt-3.5-turbo",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": f"You are a helpful assistant. Answer this question{question}",
+                },
             ],
         }
         response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
@@ -107,7 +118,9 @@ def make_openai_completion(question):
                 )
                 break
-            print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
+            print(
+                f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
+            )
             time.sleep(0.5)
     except Exception as e:
         print("got exception in polling", e)
@@ -117,11 +130,10 @@ def make_openai_completion(question):
     except Exception as e:
         # Log exceptions for failed calls
        
 with open("error_log.txt", "a") as error_log_file:
-            error_log_file.write(
-                f"Question: {question[:100]}\nException: {str(e)}\n\n"
-            )
+            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
         return None
 # Number of concurrent calls (you can adjust this)
 concurrent_calls = 10
@@ -152,4 +164,3 @@ print(f"Load test Summary:")
 print(f"Total Requests: {concurrent_calls}")
 print(f"Successful Calls: {successful_calls}")
 print(f"Failed Calls: {failed_calls}")

@@ -12,42 +12,51 @@ import pytest
 from litellm import Router
 import litellm
-litellm.set_verbose=False
+litellm.set_verbose = False
 os.environ.pop("AZURE_AD_TOKEN")
-model_list = [{ # list of model deployments
-    "model_name": "gpt-3.5-turbo", # model alias
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-v-2", # actual model name
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "azure/chatgpt-functioncalling",
-        "api_key": os.getenv("AZURE_API_KEY"),
-        "api_version": os.getenv("AZURE_API_VERSION"),
-        "api_base": os.getenv("AZURE_API_BASE")
-    }
-}, {
-    "model_name": "gpt-3.5-turbo",
-    "litellm_params": { # params for litellm completion/embedding call
-        "model": "gpt-3.5-turbo",
-        "api_key": os.getenv("OPENAI_API_KEY"),
-    }
-}]
+model_list = [
+    {  # list of model deployments
+        "model_name": "gpt-3.5-turbo",  # model alias
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",  # actual model name
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+        },
+    },
+    {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {  # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    },
+]
 router = Router(model_list=model_list)
-file_paths = ["test_questions/question1.txt", "test_questions/question2.txt", "test_questions/question3.txt"]
+file_paths = [
+    "test_questions/question1.txt",
+    "test_questions/question2.txt",
+    "test_questions/question3.txt",
+]
 questions = []
 for file_path in file_paths:
     try:
         print(file_path)
-        with open(file_path, 'r') as file:
+        with open(file_path, "r") as file:
             content = file.read()
             questions.append(content)
     except FileNotFoundError as e:
@@ -59,7 +68,6 @@ for file_path in file_paths:
 # print(q)
 # make X concurrent calls to litellm.completion(model=gpt-35-turbo, messages=[]), pick a random question in questions array.
 # Allow me to tune X concurrent calls.. Log question, output/exception, response time somewhere
 # show me a summary of requests made, success full calls, failed calls. For failed calls show me the exceptions
@@ -75,7 +83,12 @@ def make_openai_completion(question):
         start_time = time.time()
         response = router.completion(
             model="gpt-3.5-turbo",
-            messages=[{"role": "system", "content": f"You are a helpful assistant. Answer this question{question}"}],
+            messages=[
+                {
+                    "role": "system",
+                    "content": f"You are a helpful assistant. Answer this question{question}",
+                }
+            ],
         )
         print(response)
         end_time = time.time()
@@ -90,11 +103,10 @@ def make_openai_completion(question):
     except Exception as e:
         # Log exceptions for failed calls
         with open("error_log.txt", "a") as error_log_file:
-            error_log_file.write(
-                f"Question: {question[:100]}\nException: {str(e)}\n\n"
-            )
+            error_log_file.write(f"Question: {question[:100]}\nException: {str(e)}\n\n")
         return None
 # Number of concurrent calls (you can adjust this)
 concurrent_calls = 150
@@ -131,4 +143,3 @@ with open("request_log.txt", "r") as log_file:
 with open("error_log.txt", "r") as error_log_file:
     print("\nError Log:\n", error_log_file.read())

@@ -9,9 +9,15 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
-_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[
+    Callable
+] = []  # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[
+    Union[str, Callable]
+] = []  # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[
+    Callable
+] = []  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 email: Optional[
@@ -44,12 +50,80 @@ use_client: bool = False
 logging: bool = True
 caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
+cache: Optional[
+    Cache
+] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0  # set the max budget across all providers
-_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+_openai_completion_params = [
+    "functions",
+    "function_call",
+    "temperature",
+    "temperature",
+    "top_p",
+    "n",
+    "stream",
+    "stop",
+    "max_tokens",
+    "presence_penalty",
+    "frequency_penalty",
+    "logit_bias",
+    "user",
+    "request_timeout",
+    "api_base",
+    "api_version",
+    "api_key",
+    "deployment_id",
+    "organization",
+    "base_url",
+    "default_headers",
+    "timeout",
+    "response_format",
+    "seed",
+    "tools",
+    "tool_choice",
+    "max_retries",
+]
+_litellm_completion_params = [
+    "metadata",
+    "acompletion",
+    "caching",
+    "mock_response",
+    "api_key",
+    "api_version",
+    "api_base",
+    "force_timeout",
+    "logger_fn",
+    "verbose",
+    "custom_llm_provider",
+    "litellm_logging_obj",
+    "litellm_call_id",
+    "use_client",
+    "id",
+    "fallbacks",
+    "azure",
+    "headers",
+    "model_list",
+    "num_retries",
+    "context_window_fallback_dict",
+    "roles",
+    "final_prompt_value",
+    "bos_token",
+    "eos_token",
+    "request_timeout",
+    "complete_response",
+    "self",
+    "client",
+    "rpm",
+    "tpm",
+    "input_cost_per_token",
+    "output_cost_per_token",
+    "hf_model_name",
+    "model_info",
+    "proxy_server_request",
+    "preset_cache_key",
+]
 _current_cost = 0  # private variable, used if max budget is set
 error_logs: Dict = {}
 add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
@@ -66,23 +140,35 @@ fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 allowed_fails: int = 0
 ####### SECRET MANAGERS #####################
-secret_manager_client: Optional[Any] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+secret_manager_client: Optional[
+    Any
+] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
 #############################################
 def get_model_cost_map(url: str):
     try:
-        with requests.get(url, timeout=5) as response: # set a 5 second timeout for the get request
+        with requests.get(
+            url, timeout=5
+        ) as response:  # set a 5 second timeout for the get request
             response.raise_for_status()  # Raise an exception if the request is unsuccessful
             content = response.json()
             return content
     except Exception as e:
         import importlib.resources
         import json
-        with importlib.resources.open_text("litellm", "model_prices_and_context_window_backup.json") as f:
+        with importlib.resources.open_text(
+            "litellm", "model_prices_and_context_window_backup.json"
+        ) as f:
            
 content = json.load(f)
             return content
 model_cost = get_model_cost_map(url=model_cost_map_url)
-custom_prompt_dict:Dict[str, dict] = {}
+custom_prompt_dict: Dict[str, dict] = {}
 ####### THREAD-SPECIFIC DATA ###################
 class MyLocal(threading.local):
     def __init__(self):
@@ -123,39 +209,39 @@ bedrock_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 for key, value in model_cost.items():
-    if value.get('litellm_provider') == 'openai':
+    if value.get("litellm_provider") == "openai":
         open_ai_chat_completion_models.append(key)
-    elif value.get('litellm_provider') == 'text-completion-openai':
+    elif value.get("litellm_provider") == "text-completion-openai":
         open_ai_text_completion_models.append(key)
-    elif value.get('litellm_provider') == 'cohere':
+    elif value.get("litellm_provider") == "cohere":
         cohere_models.append(key)
-    elif value.get('litellm_provider') == 'anthropic':
+    elif value.get("litellm_provider") == "anthropic":
         anthropic_models.append(key)
-    elif value.get('litellm_provider') == 'openrouter':
+    elif value.get("litellm_provider") == "openrouter":
         openrouter_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-text-models':
+    elif value.get("litellm_provider") == "vertex_ai-text-models":
         vertex_text_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
+    elif value.get("litellm_provider") == "vertex_ai-code-text-models":
        
 vertex_code_text_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-language-models':
+    elif value.get("litellm_provider") == "vertex_ai-language-models":
         vertex_language_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-vision-models':
+    elif value.get("litellm_provider") == "vertex_ai-vision-models":
         vertex_vision_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-chat-models':
+    elif value.get("litellm_provider") == "vertex_ai-chat-models":
         vertex_chat_models.append(key)
-    elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
+    elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
         vertex_code_chat_models.append(key)
-    elif value.get('litellm_provider') == 'ai21':
+    elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
-    elif value.get('litellm_provider') == 'nlp_cloud':
+    elif value.get("litellm_provider") == "nlp_cloud":
         nlp_cloud_models.append(key)
-    elif value.get('litellm_provider') == 'aleph_alpha':
+    elif value.get("litellm_provider") == "aleph_alpha":
         aleph_alpha_models.append(key)
-    elif value.get('litellm_provider') == 'bedrock':
+    elif value.get("litellm_provider") == "bedrock":
         bedrock_models.append(key)
-    elif value.get('litellm_provider') == 'deepinfra':
+    elif value.get("litellm_provider") == "deepinfra":
         deepinfra_models.append(key)
-    elif value.get('litellm_provider') == 'perplexity':
+    elif value.get("litellm_provider") == "perplexity":
         perplexity_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
@@ -163,16 +249,11 @@ openai_compatible_endpoints: List = [
     "api.perplexity.ai",
     "api.endpoints.anyscale.com/v1",
     "api.deepinfra.com/v1/openai",
-    "api.mistral.ai/v1"
+    "api.mistral.ai/v1",
 ]
 # this is maintained for Exception Mapping
-openai_compatible_providers: List = [
-    "anyscale",
-    "mistral",
-    "deepinfra",
-    "perplexity"
-]
+openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
 # well supported replicate llms
@@ -209,23 +290,18 @@ huggingface_models: List = [
 together_ai_models: List = [
     # llama llms - chat
     "togethercomputer/llama-2-70b-chat",
     # llama llms - language / instruct
     "togethercomputer/llama-2-70b",
     "togethercomputer/LLaMA-2-7B-32K",
     "togethercomputer/Llama-2-7B-32K-Instruct",
     "togethercomputer/llama-2-7b",
     # falcon llms
     "togethercomputer/falcon-40b-instruct",
     "togethercomputer/falcon-7b-instruct",
     # alpaca
     "togethercomputer/alpaca-7b",
     # chat llms
     "HuggingFaceH4/starchat-alpha",
     # code llms
     "togethercomputer/CodeLlama-34b",
     "togethercomputer/CodeLlama-34b-Instruct",
@@ -234,29 +310,27 @@ together_ai_models: List = [
     "NumbersStation/nsql-llama-2-7B",
     "WizardLM/WizardCoder-15B-V1.0",
     "WizardLM/WizardCoder-Python-34B-V1.0",
     # language llms
     "NousResearch/Nous-Hermes-Llama2-13b",
     "Austism/chronos-hermes-13b",
     "upstage/SOLAR-0-70b-16bit",
     "WizardLM/WizardLM-70B-V1.0",
 ]  # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
-baseten_models: List = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML
+baseten_models: List = [
+    "qvv0xeq",
+    "q841o8w",
+    "31dxrj3",
+]  # FALCON 7B # WizardLM # Mosaic ML
 petals_models = [
     "petals-team/StableBeluga2",
 ]
-ollama_models = [
-    "llama2"
-]
+ollama_models = ["llama2"]
-maritalk_models = [
-    "maritalk"
-]
+maritalk_models = ["maritalk"]
 model_list = (
     open_ai_chat_completion_models
@@ -327,7 +401,7 @@ models_by_provider: dict = {
     "ollama": ollama_models,
     "deepinfra": deepinfra_models,
     "perplexity": perplexity_models,
-    "maritalk": maritalk_models
+    "maritalk": maritalk_models,
 }
 # mapping for those models which have larger equivalents
@@ -362,15 +436,18 @@ cohere_embedding_models: List = [
     "embed-english-light-v2.0",
     "embed-multilingual-v2.0",
 ]
-bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
+bedrock_embedding_models: List = [
+    "amazon.titan-embed-text-v1",
+    "cohere.embed-english-v3",
+    "cohere.embed-multilingual-v3",
+]
-all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+all_embedding_models = (
+    open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
+)
 ####### IMAGE GENERATION MODELS ###################
-openai_image_generation_models = [
-    "dall-e-2",
-    "dall-e-3"
-]
+openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 from .timeout import timeout
@@ -398,7 +475,7 @@ from .utils import (
     decode,
     _calculate_retry_after,
     _should_retry,
-    get_secret
+    get_secret,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
@@ -415,7 +492,13 @@ from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
+from .llms.bedrock import (
+    AmazonTitanConfig,
+    AmazonAI21Config,
+    AmazonAnthropicConfig,
+    AmazonCohereConfig,
+    AmazonLlamaConfig,
+)
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig
 from .main import *  # type: ignore
@@ -434,7 +517,7 @@ from .exceptions import (
     Timeout,
     APIConnectionError,
     APIResponseValidationError,
-    UnprocessableEntityError
+    UnprocessableEntityError,
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server

@@ -1,5 +1,6 @@
 set_verbose = False
 def print_verbose(print_statement):
     try:
         if set_verbose:

@@ -13,6 +13,7 @@ import inspect
 import redis, litellm
 from typing import List, Optional
 def _get_redis_kwargs():
     arg_spec = inspect.getfullargspec(redis.Redis)
@@ -23,23 +24,17 @@ def _get_redis_kwargs():
         "retry",
     }
-    include_args = [
-        "url"
-    ]
-    available_args = [
-        x for x in arg_spec.args if x not in exclude_args
-    ] + include_args
+    include_args = ["url"]
+    available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
     return available_args
 def _get_redis_env_kwarg_mapping():
     PREFIX = "REDIS_"
-    return {
-        f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()
-    }
+    return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
 def _redis_kwargs_from_environment():
@@ -58,14 +53,19 @@ def get_redis_url_from_environment():
         return os.environ["REDIS_URL"]
     if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
-        raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")
+        raise ValueError(
+            "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
+        )
     if "REDIS_PASSWORD" in os.environ:
         redis_password = f":{os.environ['REDIS_PASSWORD']}@"
     else:
         redis_password = ""
-    return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
+    return (
+        f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
+    )
 def get_redis_client(**env_overrides):
     ### check if "os.environ/<key-name>" passed in
@@ -80,14 +80,14 @@ def get_redis_client(**env_overrides):
         **env_overrides,
     }
-    if "url" in redis_kwargs and redis_kwargs['url'] is not None:
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         redis_kwargs.pop("host", None)
         redis_kwargs.pop("port", None)
         redis_kwargs.pop("db", None)
         redis_kwargs.pop("password", None)
         return redis.Redis.from_url(**redis_kwargs)
-    elif "host" not in redis_kwargs or redis_kwargs['host'] is None:
+    elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
     return redis.Redis(**redis_kwargs)

@ -4,8 +4,14 @@ from litellm.utils import ModelResponse
import requests, threading import requests, threading
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
class BudgetManager: class BudgetManager:
def __init__(self, project_name: str, client_type: str = "local", api_base: Optional[str] = None): def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
):
self.client_type = client_type self.client_type = client_type
self.project_name = project_name self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai" self.api_base = api_base or "https://api.litellm.ai"
@ -16,6 +22,7 @@ class BudgetManager:
try: try:
if litellm.set_verbose: if litellm.set_verbose:
import logging import logging
logging.info(print_statement) logging.info(print_statement)
except: except:
pass pass
@ -25,7 +32,7 @@ class BudgetManager:
# Check if user dict file exists # Check if user dict file exists
if os.path.isfile("user_cost.json"): if os.path.isfile("user_cost.json"):
# Load the user dict # Load the user dict
with open("user_cost.json", 'r') as json_file: with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file) self.user_dict = json.load(json_file)
else: else:
self.print_verbose("User Dictionary not found!") self.print_verbose("User Dictionary not found!")
@ -34,40 +41,55 @@ class BudgetManager:
elif self.client_type == "hosted": elif self.client_type == "hosted":
# Load the user_dict from hosted db # Load the user_dict from hosted db
url = self.api_base + "/get_budget" url = self.api_base + "/get_budget"
headers = {'Content-Type': 'application/json'} headers = {"Content-Type": "application/json"}
data = { data = {"project_name": self.project_name}
'project_name' : self.project_name
}
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
response = response.json() response = response.json()
if response["status"] == "error": if response["status"] == "error":
self.user_dict = {} # assume this means the user dict hasn't been stored yet self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else: else:
self.user_dict = response["data"] self.user_dict = response["data"]
def create_budget(self, total_budget: float, user: str, duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, created_at: float = time.time()): def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
created_at: float = time.time(),
):
self.user_dict[user] = {"total_budget": total_budget} self.user_dict[user] = {"total_budget": total_budget}
if duration is None: if duration is None:
return self.user_dict[user] return self.user_dict[user]
if duration == 'daily': if duration == "daily":
duration_in_days = 1 duration_in_days = 1
elif duration == 'weekly': elif duration == "weekly":
duration_in_days = 7 duration_in_days = 7
elif duration == 'monthly': elif duration == "monthly":
duration_in_days = 28 duration_in_days = 28
elif duration == 'yearly': elif duration == "yearly":
duration_in_days = 365 duration_in_days = 365
else: else:
raise ValueError("""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""") raise ValueError(
self.user_dict[user] = {"total_budget": total_budget, "duration": duration_in_days, "created_at": created_at, "last_updated_at": created_at} """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
)
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user] return self.user_dict[user]
def projected_cost(self, model: str, messages: list, user: str): def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages) text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text) prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0) prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0) current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost projected_cost = prompt_cost + current_cost
return projected_cost return projected_cost
@ -75,28 +97,53 @@ class BudgetManager:
def get_total_budget(self, user: str): def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"] return self.user_dict[user]["total_budget"]
def update_cost(self, user: str, completion_obj: Optional[ModelResponse] = None, model: Optional[str] = None, input_text: Optional[str] = None, output_text: Optional[str] = None): def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text: if model and input_text and output_text:
prompt_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": input_text}]) prompt_tokens = litellm.token_counter(
completion_tokens = litellm.token_counter(model=model, messages=[{"role": "user", "content": output_text}]) model=model, messages=[{"role": "user", "content": input_text}]
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) )
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj: elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj) cost = litellm.completion_cost(completion_response=completion_obj)
model = completion_obj['model'] # if this throws an error try, model = completion_obj['model'] model = completion_obj[
"model"
] # if this throws an error try, model = completion_obj['model']
else: else:
raise ValueError("Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager") raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0) self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]: if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user]["model_cost"].get(model, 0) self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else: else:
self.user_dict[user]["model_cost"] = {model: cost} self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]} return {"user": self.user_dict[user]}
def get_current_cost(self, user): def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0) return self.user_dict[user].get("current_cost", 0)
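To show how the budgeting methods above compose, here is a minimal usage sketch. The create_budget call, the model name, and the user id are illustrative assumptions taken from the litellm budget-manager docs rather than from this hunk, and an OpenAI key is assumed to be set.

import os
from litellm import BudgetManager, completion

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

budget_manager = BudgetManager(project_name="demo_project")
user = "user@example.com"
budget_manager.create_budget(total_budget=10, user=user)  # 10 USD cap (assumed helper)

messages = [{"role": "user", "content": "Hello, how are you?"}]
if budget_manager.projected_cost(
    model="gpt-3.5-turbo", messages=messages, user=user
) <= budget_manager.get_total_budget(user):
    response = completion(model="gpt-3.5-turbo", messages=messages)
    # either pass the completion object ...
    budget_manager.update_cost(user=user, completion_obj=response)
    # ... or pass raw input/output text to hit the token-counting branch above
print(budget_manager.get_current_cost(user=user))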
@ -135,7 +182,9 @@ class BudgetManager:
self.reset_on_duration(user) self.reset_on_duration(user)
def _save_data_thread(self): def _save_data_thread(self):
thread = threading.Thread(target=self.save_data) # [Non-Blocking]: saves data without blocking execution thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start() thread.start()
def save_data(self): def save_data(self):
@ -143,16 +192,15 @@ class BudgetManager:
import json import json
# save the user dict # save the user dict
with open("user_cost.json", 'w') as json_file: with open("user_cost.json", "w") as json_file:
json.dump(self.user_dict, json_file, indent=4) # Indent for pretty formatting json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"} return {"status": "success"}
elif self.client_type == "hosted": elif self.client_type == "hosted":
url = self.api_base + "/set_budget" url = self.api_base + "/set_budget"
headers = {'Content-Type': 'application/json'} headers = {"Content-Type": "application/json"}
data = { data = {"project_name": self.project_name, "user_dict": self.user_dict}
'project_name' : self.project_name,
"user_dict": self.user_dict
}
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
response = response.json() response = response.json()
return response return response

@ -12,6 +12,7 @@ import time, logging
import json, traceback, ast import json, traceback, ast
from typing import Optional, Literal, List from typing import Optional, Literal, List
def print_verbose(print_statement): def print_verbose(print_statement):
try: try:
if litellm.set_verbose: if litellm.set_verbose:
@ -19,6 +20,7 @@ def print_verbose(print_statement):
except: except:
pass pass
class BaseCache: class BaseCache:
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
raise NotImplementedError raise NotImplementedError
@ -60,6 +62,7 @@ class InMemoryCache(BaseCache):
class RedisCache(BaseCache): class RedisCache(BaseCache):
def __init__(self, host=None, port=None, password=None, **kwargs): def __init__(self, host=None, port=None, password=None, **kwargs):
import redis import redis
# if users don't provide one, use the default litellm cache
from ._redis import get_redis_client from ._redis import get_redis_client
@ -88,12 +91,18 @@ class RedisCache(BaseCache):
try: try:
print_verbose(f"Get Redis Cache: key: {key}") print_verbose(f"Get Redis Cache: key: {key}")
cached_response = self.redis_client.get(key) cached_response = self.redis_client.get(key)
print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}") print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
)
if cached_response != None: if cached_response != None:
# cached_response comes back as bytes (b"{...}") - convert it to a dict / ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string cached_response = cached_response.decode(
"utf-8"
) # Convert bytes to string
try: try:
cached_response = json.loads(cached_response) # Convert string to dictionary cached_response = json.loads(
cached_response
) # Convert string to dictionary
except: except:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
return cached_response return cached_response
@ -105,13 +114,19 @@ class RedisCache(BaseCache):
def flush_cache(self): def flush_cache(self):
self.redis_client.flushall() self.redis_client.flushall()
class DualCache(BaseCache): class DualCache(BaseCache):
""" """
This updates both Redis and an in-memory cache simultaneously. This updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis. When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data. This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
""" """
def __init__(self, in_memory_cache: Optional[InMemoryCache] =None, redis_cache: Optional[RedisCache] =None) -> None:
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
) -> None:
super().__init__() super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache # If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache() self.in_memory_cache = in_memory_cache or InMemoryCache()
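As a rough sketch of the write-through behaviour described in the DualCache docstring: the snippet below assumes DualCache exposes the same set_cache / get_cache interface as the cache classes above, and the Redis connection details are placeholders.

from litellm.caching import DualCache, InMemoryCache, RedisCache

redis_cache = RedisCache(host="localhost", port=6379, password="my-password")
dual_cache = DualCache(in_memory_cache=InMemoryCache(), redis_cache=redis_cache)

# writes land in both layers ...
dual_cache.set_cache("router:gpt-3.5-turbo:rpm", 100)
# ... so reads can be served from memory even if Redis is briefly stale
print(dual_cache.get_cache("router:gpt-3.5-turbo:rpm"))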
@ -162,6 +177,7 @@ class DualCache(BaseCache):
if self.redis_cache is not None: if self.redis_cache is not None:
self.redis_cache.flush_cache() self.redis_cache.flush_cache()
#### LiteLLM.Completion / Embedding Cache #### #### LiteLLM.Completion / Embedding Cache ####
class Cache: class Cache:
def __init__( def __init__(
@ -170,8 +186,10 @@ class Cache:
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[str] = None, port: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"], supported_call_types: Optional[
**kwargs List[Literal["completion", "acompletion", "embedding", "aembedding"]]
] = ["completion", "acompletion", "embedding", "aembedding"],
**kwargs,
): ):
""" """
Initializes the cache based on the given type. Initializes the cache based on the given type.
@ -222,8 +240,27 @@ class Cache:
return kwargs.get("litellm_params", {}).get("preset_cache_key", None) return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4] # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"] completion_kwargs = [
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs "model",
"messages",
"temperature",
"top_p",
"n",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
]
embedding_only_kwargs = [
"input",
"encoding_format",
] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set() # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
combined_kwargs = completion_kwargs + embedding_only_kwargs combined_kwargs = completion_kwargs + embedding_only_kwargs
@ -255,19 +292,30 @@ class Cache:
if model_group in group: if model_group in group:
caching_group = group caching_group = group
break break
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"] param_value = (
caching_group or model_group or kwargs[param]
) # use caching_group, if set then model_group if it exists, else use kwargs["model"]
else: else:
if kwargs[param] is None: if kwargs[param] is None:
continue # ignore None params continue # ignore None params
param_value = kwargs[param] param_value = kwargs[param]
cache_key+= f"{str(param)}: {str(param_value)}" cache_key += f"{str(param)}: {str(param_value)}"
print_verbose(f"\nCreated cache key: {cache_key}") print_verbose(f"\nCreated cache key: {cache_key}")
return cache_key return cache_key
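The key construction above concatenates "param: value" pairs in the fixed order of combined_kwargs, which is what makes the key insensitive to the order of the caller's kwargs. A standalone sketch follows (not the actual Cache method, and it ignores the caching-group handling for "model"):

def sketch_get_cache_key(**kwargs) -> str:
    combined_kwargs = ["model", "messages", "temperature", "max_tokens"]  # subset for brevity
    cache_key = ""
    for param in combined_kwargs:
        if param in kwargs and kwargs[param] is not None:
            cache_key += f"{param}: {kwargs[param]}"
    return cache_key

# the same arguments in a different order produce the same key
a = sketch_get_cache_key(model="gpt-4", messages=[{"role": "user", "content": "hi"}], temperature=0.2)
b = sketch_get_cache_key(temperature=0.2, messages=[{"role": "user", "content": "hi"}], model="gpt-4")
assert a == b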
def generate_streaming_content(self, content): def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size): for i in range(0, len(content), chunk_size):
yield {'choices': [{'delta': {'role': 'assistant', 'content': content[i:i + chunk_size]}}]} yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(0.02) time.sleep(0.02)
def get_cache(self, *args, **kwargs): def get_cache(self, *args, **kwargs):
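For context, a typical way to switch this cache on end-to-end is sketched below. The type="redis" argument is an assumption based on the constructor parameters shown above (host/port/password), so treat the exact signature as indicative rather than definitive.

import litellm
from litellm.caching import Cache
from litellm import completion

litellm.cache = Cache(
    type="redis",  # assumed flag; in-memory ("local") caching is the default
    host="localhost",
    port="6379",
    password="my-password",
    supported_call_types=["completion", "acompletion"],
)

# an identical second call can now be answered from the cache
response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi"}])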

@ -48,7 +48,6 @@
# # print("\033[92mLiteLLM: Switched on Redis caching\033[0m") # # print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
# def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'): # def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
# config = {} # config = {}
# server_settings = {} # server_settings = {}

@ -20,7 +20,7 @@ from openai import (
APITimeoutError, APITimeoutError,
APIConnectionError, APIConnectionError,
APIResponseValidationError, APIResponseValidationError,
UnprocessableEntityError UnprocessableEntityError,
) )
import httpx import httpx
@ -32,11 +32,10 @@ class AuthenticationError(AuthenticationError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise when invalid models passed, example gpt-8 # raise when invalid models passed, example gpt-8
class NotFoundError(NotFoundError): # type: ignore class NotFoundError(NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -45,9 +44,7 @@ class NotFoundError(NotFoundError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
@ -58,11 +55,10 @@ class BadRequestError(BadRequestError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 422 self.status_code = 422
@ -70,11 +66,10 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
self.model = model self.model = model
self.llm_provider = llm_provider self.llm_provider = llm_provider
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class Timeout(APITimeoutError): # type: ignore class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
self.status_code = 408 self.status_code = 408
@ -86,6 +81,7 @@ class Timeout(APITimeoutError): # type: ignore
request=request request=request
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429 self.status_code = 429
@ -93,11 +89,10 @@ class RateLimitError(RateLimitError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# subclass of BadRequestError - gives more granularity for handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response): def __init__(self, message, model, llm_provider, response: httpx.Response):
@ -109,9 +104,10 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
message=self.message, message=self.message,
model=self.model, # type: ignore model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore llm_provider=self.llm_provider, # type: ignore
response=response response=response,
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(APIStatusError): # type: ignore class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response): def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 503 self.status_code = 503
@ -119,24 +115,21 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(
self.message, self.message, response=response, body=None
response=response,
body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore class APIError(APIError): # type: ignore
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request): def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request
):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
super().__init__( super().__init__(self.message, request=request, body=None) # type: ignore
self.message,
request=request, # type: ignore
body=None
)
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(APIConnectionError): # type: ignore class APIConnectionError(APIConnectionError): # type: ignore
@ -145,10 +138,8 @@ class APIConnectionError(APIConnectionError): # type: ignore
self.llm_provider = llm_provider self.llm_provider = llm_provider
self.model = model self.model = model
self.status_code = 500 self.status_code = 500
super().__init__( super().__init__(message=self.message, request=request)
message=self.message,
request=request
)
# raised if an invalid request (not get, delete, put, post) is made # raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore class APIResponseValidationError(APIResponseValidationError): # type: ignore
@ -158,11 +149,8 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
self.model = model self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1") request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request) response = httpx.Response(status_code=500, request=request)
super().__init__( super().__init__(response=response, body=None, message=message)
response=response,
body=None,
message=message
)
class OpenAIError(OpenAIError): # type: ignore class OpenAIError(OpenAIError): # type: ignore
def __init__(self, original_exception): def __init__(self, original_exception):
@ -176,6 +164,7 @@ class OpenAIError(OpenAIError): # type: ignore
) )
self.llm_provider = "openai" self.llm_provider = "openai"
class BudgetExceededError(Exception): class BudgetExceededError(Exception):
def __init__(self, current_cost, max_budget): def __init__(self, current_cost, max_budget):
self.current_cost = current_cost self.current_cost = current_cost
@ -183,6 +172,7 @@ class BudgetExceededError(Exception):
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}" message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
super().__init__(message) super().__init__(message)
## DEPRECATED ## ## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore class InvalidRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider): def __init__(self, message, model, llm_provider):
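Because this module remaps provider failures onto the OpenAI exception types above, callers can branch on them directly. A rough sketch (model and prompt are arbitrary, and which exception fires depends on the provider):

from litellm import completion
from litellm.exceptions import ContextWindowExceededError, RateLimitError

try:
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "a very long prompt ..."}],
    )
except ContextWindowExceededError as e:
    # subclass of BadRequestError; carries .model / .llm_provider for fallback logic
    print(f"prompt too long for {e.model} ({e.llm_provider}) - trim it or switch models")
except RateLimitError as e:
    print(f"rate limited (status {e.status_code}) - back off and retry")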

@ -5,6 +5,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal from typing import Literal
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -47,10 +48,19 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
""" """
Control or modify the incoming / outgoing data before calling the model
""" """
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings"],
):
pass pass
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass pass
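To illustrate the proxy hooks defined here, a minimal subclass is sketched below. It assumes the proxy passes the parsed request body as `data`, uses the hook's return value, and that the handler is registered through the proxy's callback settings; none of that is shown in this hunk.

from typing import Literal

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


class MyProxyHooks(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal["completion", "embeddings"],
    ):
        # e.g. cap max_tokens on every completion request before it reaches the model
        if call_type == "completion":
            data["max_tokens"] = min(data.get("max_tokens", 256), 256)
        return data  # returning the (possibly modified) request body is assumed here

    async def async_post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        print(f"request failed: {original_exception}")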
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
@ -63,14 +73,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
callback_func( callback_func(
kwargs, kwargs,
) )
print_verbose( print_verbose(f"Custom Logger - model call details: {kwargs}")
f"Custom Logger - model call details: {kwargs}"
)
except: except:
traceback.print_exc() traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(self, model, messages, kwargs, print_verbose, callback_func): async def async_log_input_event(
self, model, messages, kwargs, print_verbose, callback_func
):
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
@ -78,15 +88,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
await callback_func( await callback_func(
kwargs, kwargs,
) )
print_verbose( print_verbose(f"Custom Logger - model call details: {kwargs}")
f"Custom Logger - model call details: {kwargs}"
)
except: except:
traceback.print_exc() traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition # Method definition
try: try:
kwargs["log_event_type"] = "post_api_call" kwargs["log_event_type"] = "post_api_call"
@ -96,15 +105,15 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
start_time, start_time,
end_time, end_time,
) )
print_verbose( print_verbose(f"Custom Logger - final response object: {response_obj}")
f"Custom Logger - final response object: {response_obj}"
)
except: except:
# traceback.print_exc() # traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
async def async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func): async def async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
):
# Method definition # Method definition
try: try:
kwargs["log_event_type"] = "post_api_call" kwargs["log_event_type"] = "post_api_call"
@ -114,9 +123,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
start_time, start_time,
end_time, end_time,
) )
print_verbose( print_verbose(f"Custom Logger - final response object: {response_obj}")
f"Custom Logger - final response object: {response_obj}"
)
except: except:
# traceback.print_exc() # traceback.print_exc()
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

@ -10,19 +10,28 @@ import datetime, subprocess, sys
import litellm, uuid import litellm, uuid
from litellm._logging import print_verbose from litellm._logging import print_verbose
class DyanmoDBLogger: class DyanmoDBLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
# Instance variables # Instance variables
import boto3 import boto3
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
self.dynamodb = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None: if litellm.dynamodb_table_name is None:
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`") raise ValueError(
"LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`"
)
self.table_name = litellm.dynamodb_table_name self.table_name = litellm.dynamodb_table_name
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try: try:
print_verbose( print_verbose(
@ -32,7 +41,9 @@ class DyanmoDBLogger:
# construct payload to send to DynamoDB # construct payload to send to DynamoDB
# follows the same params as langfuse.py # follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages") messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion") call_type = kwargs.get("call_type", "litellm.completion")
@ -51,7 +62,7 @@ class DyanmoDBLogger:
"messages": messages, "messages": messages,
"response": response_obj, "response": response_obj,
"usage": usage, "usage": usage,
"metadata": metadata "metadata": metadata,
} }
# Ensure everything in the payload is converted to str # Ensure everything in the payload is converted to str
@ -62,7 +73,6 @@ class DyanmoDBLogger:
# non blocking if it can't cast to a str # non blocking if it can't cast to a str
pass pass
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}") print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
# put data in DynamoDB
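For reference, wiring this logger up usually looks like the sketch below; the table name and region are placeholders, and the "dynamodb" callback alias follows the usual litellm callback convention rather than anything shown in this hunk.

import os
import litellm
from litellm import completion

os.environ["AWS_REGION_NAME"] = "us-west-2"           # region used by the boto3 resource above
litellm.dynamodb_table_name = "litellm-request-logs"  # must exist, per the ValueError above
litellm.success_callback = ["dynamodb"]               # assumed callback alias for this logger

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    metadata={"environment": "staging"},  # surfaces in the payload's metadata field
)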

@ -64,7 +64,9 @@ class LangFuseLogger:
# end of processing langfuse ######################## # end of processing langfuse ########################
input = prompt input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
print(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}") print(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj['choices'][0]['message']}"
)
self._log_langfuse_v2( self._log_langfuse_v2(
user_id, user_id,
metadata, metadata,
@ -171,7 +173,6 @@ class LangFuseLogger:
user_id=user_id, user_id=user_id,
) )
trace.generation( trace.generation(
name=metadata.get("generation_name", "litellm-completion"), name=metadata.get("generation_name", "litellm-completion"),
startTime=start_time, startTime=start_time,

@ -8,12 +8,12 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class LangsmithLogger: class LangsmithLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY") self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
# Method definition # Method definition
# inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb # inspired by Langsmith http api here: https://github.com/langchain-ai/langsmith-cookbook/blob/main/tracing-examples/rest/rest.ipynb
@ -26,7 +26,9 @@ class LangsmithLogger:
# if not set litellm will use default project_name = litellm-completion, run_name = LLMRun # if not set litellm will use default project_name = litellm-completion, run_name = LLMRun
project_name = metadata.get("project_name", "litellm-completion") project_name = metadata.get("project_name", "litellm-completion")
run_name = metadata.get("run_name", "LLMRun") run_name = metadata.get("run_name", "LLMRun")
print_verbose(f"Langsmith Logging - project_name: {project_name}, run_name {run_name}") print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
try: try:
print_verbose( print_verbose(
f"Langsmith Logging - Enters logging function for model {kwargs}" f"Langsmith Logging - Enters logging function for model {kwargs}"
@ -34,6 +36,7 @@ class LangsmithLogger:
import requests import requests
import datetime import datetime
from datetime import timezone from datetime import timezone
try: try:
start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat() start_time = kwargs["start_time"].astimezone(timezone.utc).isoformat()
end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat() end_time = kwargs["end_time"].astimezone(timezone.utc).isoformat()
@ -45,7 +48,7 @@ class LangsmithLogger:
new_kwargs = {} new_kwargs = {}
for key in kwargs: for key in kwargs:
value = kwargs[key] value = kwargs[key]
if key == "start_time" or key =="end_time": if key == "start_time" or key == "end_time":
pass pass
elif type(value) != dict: elif type(value) != dict:
new_kwargs[key] = value new_kwargs[key] = value
@ -55,17 +58,13 @@ class LangsmithLogger:
json={ json={
"name": run_name, "name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": { "inputs": {**new_kwargs},
**new_kwargs
},
"outputs": response_obj.json(), "outputs": response_obj.json(),
"session_name": project_name, "session_name": project_name,
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,
}, },
headers={ headers={"x-api-key": self.langsmith_api_key},
"x-api-key": self.langsmith_api_key
}
) )
print_verbose( print_verbose(
f"Langsmith Layer Logging - final response object: {response_obj}" f"Langsmith Layer Logging - final response object: {response_obj}"

@ -1,6 +1,7 @@
import requests, traceback, json, os import requests, traceback, json, os
import types import types
class LiteDebugger: class LiteDebugger:
user_email = None user_email = None
dashboard_url = None dashboard_url = None
@ -12,9 +13,15 @@ class LiteDebugger:
def validate_environment(self, email): def validate_environment(self, email):
try: try:
self.user_email = (email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")) self.user_email = (
if self.user_email == None: # if users are trying to use_client=True but token not set email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
raise ValueError("litellm.use_client = True but no token or email passed. Please set it in litellm.token") )
if (
self.user_email == None
): # if users are trying to use_client=True but token not set
raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
)
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try: try:
print( print(
@ -42,7 +49,9 @@ class LiteDebugger:
litellm_params, litellm_params,
optional_params, optional_params,
): ):
print_verbose(f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}") print_verbose(
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
)
try: try:
print_verbose( print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
@ -56,7 +65,11 @@ class LiteDebugger:
updated_litellm_params = remove_key_value(litellm_params, "logger_fn") updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding": if call_type == "embedding":
for message in messages: # assuming the input is a list as required by the embedding function for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = { litellm_data_obj = {
"model": model, "model": model,
"messages": [{"role": "user", "content": message}], "messages": [{"role": "user", "content": message}],
@ -79,7 +92,9 @@ class LiteDebugger:
elif call_type == "completion": elif call_type == "completion":
litellm_data_obj = { litellm_data_obj = {
"model": model, "model": model,
"messages": messages if isinstance(messages, list) else [{"role": "user", "content": messages}], "messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user, "end_user": end_user,
"status": "initiated", "status": "initiated",
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
@ -95,20 +110,30 @@ class LiteDebugger:
headers={"content-type": "application/json"}, headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj), data=json.dumps(litellm_data_obj),
) )
print_verbose(f"LiteDebugger: completion api response - {response.text}") print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except: except:
print_verbose( print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
) )
pass pass
def post_call_log_event(self, original_response, litellm_call_id, print_verbose, call_type, stream): def post_call_log_event(
print_verbose(f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}") self, original_response, litellm_call_id, print_verbose, call_type, stream
):
print_verbose(
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
)
try: try:
if call_type == "embedding": if call_type == "embedding":
litellm_data_obj = { litellm_data_obj = {
"status": "received", "status": "received",
"additional_details": {"original_response": str(original_response["data"][0]["embedding"][:5])}, # don't store the entire vector "additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"user_email": self.user_email, "user_email": self.user_email,
} }
@ -122,7 +147,11 @@ class LiteDebugger:
elif call_type == "completion" and stream: elif call_type == "completion" and stream:
litellm_data_obj = { litellm_data_obj = {
"status": "received", "status": "received",
"additional_details": {"original_response": "Streamed response" if isinstance(original_response, types.GeneratorType) else original_response}, "additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id, "litellm_call_id": litellm_call_id,
"user_email": self.user_email, "user_email": self.user_email,
} }
@ -147,9 +176,11 @@ class LiteDebugger:
litellm_call_id, litellm_call_id,
print_verbose, print_verbose,
call_type, call_type,
stream = False stream=False,
): ):
print_verbose(f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}") print_verbose(
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
)
try: try:
print_verbose( print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}" f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"

@ -18,19 +18,17 @@ class PromptLayerLogger:
# Method definition # Method definition
try: try:
new_kwargs = {} new_kwargs = {}
new_kwargs['model'] = kwargs['model'] new_kwargs["model"] = kwargs["model"]
new_kwargs['messages'] = kwargs['messages'] new_kwargs["messages"] = kwargs["messages"]
# add kwargs["optional_params"] to new_kwargs # add kwargs["optional_params"] to new_kwargs
for optional_param in kwargs["optional_params"]: for optional_param in kwargs["optional_params"]:
new_kwargs[optional_param] = kwargs["optional_params"][optional_param] new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
print_verbose( print_verbose(
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}" f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
) )
request_response = requests.post( request_response = requests.post(
"https://api.promptlayer.com/rest/track-request", "https://api.promptlayer.com/rest/track-request",
json={ json={
@ -62,10 +60,12 @@ class PromptLayerLogger:
json={ json={
"request_id": response_json["request_id"], "request_id": response_json["request_id"],
"api_key": self.key, "api_key": self.key,
"metadata": kwargs["litellm_params"]["metadata"] "metadata": kwargs["litellm_params"]["metadata"],
}, },
) )
print_verbose(f"Prompt Layer Logging: success - metadata post response object: {response.text}") print_verbose(
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
)
except: except:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}") print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")

@ -9,6 +9,7 @@ import traceback
import datetime, subprocess, sys import datetime, subprocess, sys
import litellm import litellm
class Supabase: class Supabase:
# Class variables or attributes # Class variables or attributes
supabase_table_name = "request_logs" supabase_table_name = "request_logs"

@ -2,6 +2,7 @@ class TraceloopLogger:
def __init__(self): def __init__(self):
from traceloop.sdk.tracing.tracing import TracerWrapper from traceloop.sdk.tracing.tracing import TracerWrapper
from traceloop.sdk import Traceloop from traceloop.sdk import Traceloop
Traceloop.init(app_name="Litellm-Server", disable_batch=True) Traceloop.init(app_name="Litellm-Server", disable_batch=True)
self.tracer_wrapper = TracerWrapper() self.tracer_wrapper = TracerWrapper()
@ -29,15 +30,18 @@ class TraceloopLogger:
) )
if "stop" in optional_params: if "stop" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_CHAT_STOP_SEQUENCES, optional_params.get("stop") SpanAttributes.LLM_CHAT_STOP_SEQUENCES,
optional_params.get("stop"),
) )
if "frequency_penalty" in optional_params: if "frequency_penalty" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_FREQUENCY_PENALTY, optional_params.get("frequency_penalty") SpanAttributes.LLM_FREQUENCY_PENALTY,
optional_params.get("frequency_penalty"),
) )
if "presence_penalty" in optional_params: if "presence_penalty" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_PRESENCE_PENALTY, optional_params.get("presence_penalty") SpanAttributes.LLM_PRESENCE_PENALTY,
optional_params.get("presence_penalty"),
) )
if "top_p" in optional_params: if "top_p" in optional_params:
span.set_attribute( span.set_attribute(
@ -45,7 +49,10 @@ class TraceloopLogger:
) )
if "tools" in optional_params or "functions" in optional_params: if "tools" in optional_params or "functions" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_REQUEST_FUNCTIONS, optional_params.get("tools", optional_params.get("functions")) SpanAttributes.LLM_REQUEST_FUNCTIONS,
optional_params.get(
"tools", optional_params.get("functions")
),
) )
if "user" in optional_params: if "user" in optional_params:
span.set_attribute( span.set_attribute(
@ -53,7 +60,8 @@ class TraceloopLogger:
) )
if "max_tokens" in optional_params: if "max_tokens" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_tokens") SpanAttributes.LLM_REQUEST_MAX_TOKENS,
kwargs.get("max_tokens"),
) )
if "temperature" in optional_params: if "temperature" in optional_params:
span.set_attribute( span.set_attribute(

@ -1,4 +1,4 @@
imported_openAIResponse=True imported_openAIResponse = True
try: try:
import io import io
import logging import logging
@ -12,14 +12,11 @@ try:
else: else:
from typing_extensions import Literal, Protocol from typing_extensions import Literal, Protocol
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
K = TypeVar("K", bound=str) K = TypeVar("K", bound=str)
V = TypeVar("V") V = TypeVar("V")
class OpenAIResponse(Protocol[K, V]): # type: ignore class OpenAIResponse(Protocol[K, V]): # type: ignore
# contains a (known) object attribute # contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"] object: Literal["chat.completion", "edit", "text_completion"]
@ -30,7 +27,6 @@ try:
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover ... # pragma: no cover
class OpenAIRequestResponseResolver: class OpenAIRequestResponseResolver:
def __call__( def __call__(
self, self,
@ -44,7 +40,9 @@ try:
elif response["object"] == "text_completion": elif response["object"] == "text_completion":
return self._resolve_completion(request, response, time_elapsed) return self._resolve_completion(request, response, time_elapsed)
elif response["object"] == "chat.completion": elif response["object"] == "chat.completion":
return self._resolve_chat_completion(request, response, time_elapsed) return self._resolve_chat_completion(
request, response, time_elapsed
)
else: else:
logger.info(f"Unknown OpenAI response object: {response['object']}") logger.info(f"Unknown OpenAI response object: {response['object']}")
except Exception as e: except Exception as e:
@ -113,7 +111,8 @@ try:
"""Resolves the request and response objects for `openai.Completion`.""" """Resolves the request and response objects for `openai.Completion`."""
request_str = f"\n\n**Prompt**: {request['prompt']}\n" request_str = f"\n\n**Prompt**: {request['prompt']}\n"
choices = [ choices = [
f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"] f"\n\n**Completion**: {choice['text']}\n"
for choice in response["choices"]
] ]
return self._request_response_result_to_trace( return self._request_response_result_to_trace(
@ -167,9 +166,9 @@ try:
] ]
trace = self.results_to_trace_tree(request, response, results, time_elapsed) trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace return trace
except:
imported_openAIResponse=False
except:
imported_openAIResponse = False
#### What this does #### #### What this does ####
@ -182,15 +181,20 @@ from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class WeightsBiasesLogger: class WeightsBiasesLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
try: try:
import wandb import wandb
except: except:
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") raise Exception(
if imported_openAIResponse==False: "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
raise Exception("\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m") )
if imported_openAIResponse == False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
self.resolver = OpenAIRequestResponseResolver() self.resolver = OpenAIRequestResponseResolver()
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
@ -198,13 +202,13 @@ class WeightsBiasesLogger:
import wandb import wandb
try: try:
print_verbose( print_verbose(f"W&B Logging - Enters logging function for model {kwargs}")
f"W&B Logging - Enters logging function for model {kwargs}"
)
run = wandb.init() run = wandb.init()
print_verbose(response_obj) print_verbose(response_obj)
trace = self.resolver(kwargs, response_obj, (end_time-start_time).total_seconds()) trace = self.resolver(
kwargs, response_obj, (end_time - start_time).total_seconds()
)
if trace is not None: if trace is not None:
run.log({"trace": trace}) run.log({"trace": trace})

@ -7,17 +7,21 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message from litellm.utils import ModelResponse, Choices, Message
import litellm import litellm
class AI21Error(Exception): class AI21Error(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/") self.request = httpx.Request(
method="POST", url="https://api.ai21.com/studio/v1/"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AI21Config():
class AI21Config:
""" """
Reference: https://docs.ai21.com/reference/j2-complete-ref Reference: https://docs.ai21.com/reference/j2-complete-ref
@ -43,40 +47,53 @@ class AI21Config():
- `countPenalty` (object): Placeholder for count penalty object. - `countPenalty` (object): Placeholder for count penalty object.
""" """
    numResults: Optional[int] = None
    maxTokens: Optional[int] = None
    minTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    topKReturn: Optional[int] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        numResults: Optional[int] = None,
        maxTokens: Optional[int] = None,
        minTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        topKReturn: Optional[int] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
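Config classes like this one let callers set provider-wide defaults once, while per-call kwargs still win (per the `completion(top_k=3) > ai21_config(top_k=3)` comment in the completion function below). A sketch, with the API key env var and model name as assumptions:

import os
import litellm
from litellm import completion

os.environ["AI21_API_KEY"] = "..."  # placeholder

litellm.AI21Config(maxTokens=256, temperature=0.7)  # provider-wide defaults

response = completion(
    model="j2-mid",
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.2,  # per-call value takes precedence over the config default
)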
def validate_environment(api_key): def validate_environment(api_key):
@ -91,6 +108,7 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -110,20 +128,18 @@ def completion(
for message in messages: for message in messages:
if "role" in message: if "role" in message:
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
## Load Config ## Load Config
config = litellm.AI21Config.get_config() config = litellm.AI21Config.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -143,10 +159,7 @@ def completion(
api_base + model + "/complete", headers=headers, data=json.dumps(data) api_base + model + "/complete", headers=headers, data=json.dumps(data)
) )
if response.status_code != 200: if response.status_code != 200:
raise AI21Error( raise AI21Error(status_code=response.status_code, message=response.text)
status_code=response.status_code,
message=response.text
)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
@ -166,16 +179,20 @@ def completion(
message_obj = Message(content=item["data"]["text"]) message_obj = Message(content=item["data"]["text"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finishReason"]["reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finishReason"]["reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise AI21Error(message=traceback.format_exc(), status_code=response.status_code) raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
@ -189,6 +206,7 @@ def completion(
} }
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass

@ -8,17 +8,21 @@ import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx import httpx
class AlephAlphaError(Exception): class AlephAlphaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.aleph-alpha.com/complete") self.request = httpx.Request(
method="POST", url="https://api.aleph-alpha.com/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AlephAlphaConfig():
class AlephAlphaConfig:
""" """
Reference: https://docs.aleph-alpha.com/api/complete/ Reference: https://docs.aleph-alpha.com/api/complete/
@ -72,83 +76,97 @@ class AlephAlphaConfig():
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores. - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
""" """
maximum_tokens: Optional[int]=litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int]=None
echo: Optional[bool]=None
temperature: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
presence_penalty: Optional[int]=None
frequency_penalty: Optional[int]=None
sequence_penalty: Optional[int]=None
sequence_penalty_min_length: Optional[int]=None
repetition_penalties_include_prompt: Optional[bool]=None
repetition_penalties_include_completion: Optional[bool]=None
use_multiplicative_presence_penalty: Optional[bool]=None
use_multiplicative_frequency_penalty: Optional[bool]=None
use_multiplicative_sequence_penalty: Optional[bool]=None
penalty_bias: Optional[str]=None
penalty_exceptions_include_stop_sequences: Optional[bool]=None
best_of: Optional[int]=None
n: Optional[int]=None
logit_bias: Optional[dict]=None
log_probs: Optional[int]=None
stop_sequences: Optional[list]=None
tokens: Optional[bool]=None
raw_completion: Optional[bool]=None
disable_optimizations: Optional[bool]=None
completion_bias_inclusion: Optional[list]=None
completion_bias_exclusion: Optional[list]=None
completion_bias_inclusion_first_token_only: Optional[bool]=None
completion_bias_exclusion_first_token_only: Optional[bool]=None
contextual_control_threshold: Optional[int]=None
control_log_additive: Optional[bool]=None
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
presence_penalty: Optional[int] = None
frequency_penalty: Optional[int] = None
sequence_penalty: Optional[int] = None
sequence_penalty_min_length: Optional[int] = None
repetition_penalties_include_prompt: Optional[bool] = None
repetition_penalties_include_completion: Optional[bool] = None
use_multiplicative_presence_penalty: Optional[bool] = None
use_multiplicative_frequency_penalty: Optional[bool] = None
use_multiplicative_sequence_penalty: Optional[bool] = None
penalty_bias: Optional[str] = None
penalty_exceptions_include_stop_sequences: Optional[bool] = None
best_of: Optional[int] = None
n: Optional[int] = None
logit_bias: Optional[dict] = None
log_probs: Optional[int] = None
stop_sequences: Optional[list] = None
tokens: Optional[bool] = None
raw_completion: Optional[bool] = None
disable_optimizations: Optional[bool] = None
completion_bias_inclusion: Optional[list] = None
completion_bias_exclusion: Optional[list] = None
completion_bias_inclusion_first_token_only: Optional[bool] = None
completion_bias_exclusion_first_token_only: Optional[bool] = None
contextual_control_threshold: Optional[int] = None
control_log_additive: Optional[bool] = None
def __init__(self, def __init__(
maximum_tokens: Optional[int]=None, self,
minimum_tokens: Optional[int]=None, maximum_tokens: Optional[int] = None,
echo: Optional[bool]=None, minimum_tokens: Optional[int] = None,
temperature: Optional[int]=None, echo: Optional[bool] = None,
top_k: Optional[int]=None, temperature: Optional[int] = None,
top_p: Optional[int]=None, top_k: Optional[int] = None,
presence_penalty: Optional[int]=None, top_p: Optional[int] = None,
frequency_penalty: Optional[int]=None, presence_penalty: Optional[int] = None,
sequence_penalty: Optional[int]=None, frequency_penalty: Optional[int] = None,
sequence_penalty_min_length: Optional[int]=None, sequence_penalty: Optional[int] = None,
repetition_penalties_include_prompt: Optional[bool]=None, sequence_penalty_min_length: Optional[int] = None,
repetition_penalties_include_completion: Optional[bool]=None, repetition_penalties_include_prompt: Optional[bool] = None,
use_multiplicative_presence_penalty: Optional[bool]=None, repetition_penalties_include_completion: Optional[bool] = None,
use_multiplicative_frequency_penalty: Optional[bool]=None, use_multiplicative_presence_penalty: Optional[bool] = None,
use_multiplicative_sequence_penalty: Optional[bool]=None, use_multiplicative_frequency_penalty: Optional[bool] = None,
penalty_bias: Optional[str]=None, use_multiplicative_sequence_penalty: Optional[bool] = None,
penalty_exceptions_include_stop_sequences: Optional[bool]=None, penalty_bias: Optional[str] = None,
best_of: Optional[int]=None, penalty_exceptions_include_stop_sequences: Optional[bool] = None,
n: Optional[int]=None, best_of: Optional[int] = None,
logit_bias: Optional[dict]=None, n: Optional[int] = None,
log_probs: Optional[int]=None, logit_bias: Optional[dict] = None,
stop_sequences: Optional[list]=None, log_probs: Optional[int] = None,
tokens: Optional[bool]=None, stop_sequences: Optional[list] = None,
raw_completion: Optional[bool]=None, tokens: Optional[bool] = None,
disable_optimizations: Optional[bool]=None, raw_completion: Optional[bool] = None,
completion_bias_inclusion: Optional[list]=None, disable_optimizations: Optional[bool] = None,
completion_bias_exclusion: Optional[list]=None, completion_bias_inclusion: Optional[list] = None,
completion_bias_inclusion_first_token_only: Optional[bool]=None, completion_bias_exclusion: Optional[list] = None,
completion_bias_exclusion_first_token_only: Optional[bool]=None, completion_bias_inclusion_first_token_only: Optional[bool] = None,
contextual_control_threshold: Optional[int]=None, completion_bias_exclusion_first_token_only: Optional[bool] = None,
control_log_additive: Optional[bool]=None) -> None: contextual_control_threshold: Optional[int] = None,
control_log_additive: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
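
The class above promotes non-None constructor arguments to class attributes and exposes them through get_config(), which filters out dunders, callables, and unset values. A minimal standalone sketch of that pattern, with illustrative names and independent of litellm:

import types
from typing import Optional


class ExampleConfig:
    # class-level defaults; __init__ promotes non-None arguments to class attributes
    temperature: Optional[float] = None
    top_k: Optional[int] = None

    def __init__(self, temperature: Optional[float] = None, top_k: Optional[int] = None) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # keep only plain, non-None data attributes
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)
            )
            and v is not None
        }


ExampleConfig(temperature=0.2)
print(ExampleConfig.get_config())  # {'temperature': 0.2}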
def validate_environment(api_key): def validate_environment(api_key):
@ -160,6 +178,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -179,7 +198,9 @@ def completion(
## Load Config ## Load Config
config = litellm.AlephAlphaConfig.get_config() config = litellm.AlephAlphaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > aleph_alpha_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
completion_url = api_base completion_url = api_base
@ -188,21 +209,17 @@ def completion(
if "control" in model: # follow the ###Instruction / ###Response format if "control" in model: # follow the ###Instruction / ###Response format
for idx, message in enumerate(messages): for idx, message in enumerate(messages):
if "role" in message: if "role" in message:
if idx == 0: # set first message as instruction (required), let later user messages be input if (
idx == 0
): # set first message as instruction (required), let later user messages be input
prompt += f"###Instruction: {message['content']}" prompt += f"###Instruction: {message['content']}"
else: else:
if message["role"] == "system": if message["role"] == "system":
prompt += ( prompt += f"###Instruction: {message['content']}"
f"###Instruction: {message['content']}"
)
elif message["role"] == "user": elif message["role"] == "user":
prompt += ( prompt += f"###Input: {message['content']}"
f"###Input: {message['content']}"
)
else: else:
prompt += ( prompt += f"###Response: {message['content']}"
f"###Response: {message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
else: else:
@ -221,7 +238,10 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
@ -249,16 +269,21 @@ def completion(
message_obj = Message(content=item["completion"]) message_obj = Message(content=item["completion"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except: except:
raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code) raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"]) encoding.encode(model_response["choices"][0]["message"]["content"])
) )
@ -268,11 +293,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
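
For Aleph Alpha "control" models, the completion path above flattens chat messages into the ###Instruction / ###Input / ###Response format. A small sketch of just that transformation, assuming simple message dicts rather than the full litellm code path:

def build_control_prompt(messages: list) -> str:
    # mirrors the loop above: the first message becomes the instruction,
    # later messages are mapped by role
    prompt = ""
    for idx, message in enumerate(messages):
        if "role" in message:
            if idx == 0:
                prompt += f"###Instruction: {message['content']}"
            elif message["role"] == "system":
                prompt += f"###Instruction: {message['content']}"
            elif message["role"] == "user":
                prompt += f"###Input: {message['content']}"
            else:
                prompt += f"###Response: {message['content']}"
        else:
            prompt += f"{message['content']}"
    return prompt


print(
    build_control_prompt(
        [
            {"role": "system", "content": "Be terse."},
            {"role": "user", "content": "Name a prime number."},
        ]
    )
)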
View file
@ -9,52 +9,72 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
class AnthropicError(Exception): class AnthropicError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.anthropic.com/v1/complete") self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AnthropicConfig():
class AnthropicConfig:
""" """
Reference: https://docs.anthropic.com/claude/reference/complete_post Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
""" """
max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
metadata: Optional[dict]=None
def __init__(self, max_tokens_to_sample: Optional[
max_tokens_to_sample: Optional[int]=256, # anthropic requires a default int
stop_sequences: Optional[list]=None, ] = litellm.max_tokens # anthropic requires a default
temperature: Optional[int]=None, stop_sequences: Optional[list] = None
top_p: Optional[int]=None, temperature: Optional[int] = None
top_k: Optional[int]=None, top_p: Optional[int] = None
metadata: Optional[dict]=None) -> None: top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
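
Further down, the "## Load Config" loop merges these class-level defaults into the per-call optional_params, and arguments passed directly to completion() take precedence. A rough sketch of that precedence rule on plain dicts, with illustrative values only:

config = {"max_tokens_to_sample": 256, "temperature": 0.7}  # e.g. AnthropicConfig.get_config()
optional_params = {"temperature": 0.2}  # e.g. set via completion(temperature=0.2)

for k, v in config.items():
    if k not in optional_params:  # per-call kwargs win over the class config
        optional_params[k] = v

print(optional_params)  # {'temperature': 0.2, 'max_tokens_to_sample': 256}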
# makes headers for API call # makes headers for API call
@ -71,6 +91,7 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -93,15 +114,19 @@ def completion(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="anthropic") prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config ## Load Config
config = litellm.AnthropicConfig.get_config() config = litellm.AnthropicConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -127,15 +152,17 @@ def completion(
) )
if response.status_code != 200: if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text) raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines() return response.iter_lines()
else: else:
response = requests.post( response = requests.post(api_base, headers=headers, data=json.dumps(data))
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200: if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text) raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -159,9 +186,9 @@ def completion(
) )
else: else:
if len(completion_response["completion"]) > 0: if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response[ model_response["choices"][0]["message"][
"completion" "content"
] ] = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response["stop_reason"] model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE ## CALCULATING USAGE
@ -177,11 +204,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
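
The usage block above counts prompt and completion tokens with a tokenizer and sums them into a Usage object. A standalone approximation that uses tiktoken's cl100k_base encoding as a stand-in tokenizer (litellm's own encoding object may differ):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

prompt = "\n\nHuman: hello\n\nAssistant: "
completion_text = "Hi there!"

prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(completion_text))
total_tokens = prompt_tokens + completion_tokens

print(prompt_tokens, completion_tokens, total_tokens)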
View file
@ -1,7 +1,13 @@
from typing import Optional, Union, Any from typing import Optional, Union, Any
import types, requests import types, requests
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
)
from typing import Callable, Optional from typing import Callable, Optional
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
@ -9,8 +15,15 @@ import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
if request: if request:
@ -20,11 +33,14 @@ class AzureOpenAIError(Exception):
if response: if response:
self.response = response self.response = response
else: else:
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
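
AzureOpenAIError attaches synthetic httpx Request/Response objects when the caller does not supply them, so downstream exception mapping always has something to inspect. A hedged sketch of that construction with a placeholder endpoint and class name:

import httpx


class ExampleProviderError(Exception):
    def __init__(self, status_code, message, request=None, response=None):
        self.status_code = status_code
        self.message = message
        # fall back to synthetic httpx objects so callers can always inspect them
        self.request = request or httpx.Request(method="POST", url="https://example.invalid/v1/chat")
        self.response = response or httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)


err = ExampleProviderError(status_code=422, message="Missing model or messages")
print(err.response.status_code)  # 422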
class AzureOpenAIConfig(OpenAIConfig): class AzureOpenAIConfig(OpenAIConfig):
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -52,18 +68,21 @@ class AzureOpenAIConfig(OpenAIConfig):
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
""" """
def __init__(self, def __init__(
self,
frequency_penalty: Optional[int] = None, frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]]= None, function_call: Optional[Union[str, dict]] = None,
functions: Optional[list]= None, functions: Optional[list] = None,
logit_bias: Optional[dict]= None, logit_bias: Optional[dict] = None,
max_tokens: Optional[int]= None, max_tokens: Optional[int] = None,
n: Optional[int]= None, n: Optional[int] = None,
presence_penalty: Optional[int]= None, presence_penalty: Optional[int] = None,
stop: Optional[Union[str,list]]=None, stop: Optional[Union[str, list]] = None,
temperature: Optional[int]= None, temperature: Optional[int] = None,
top_p: Optional[int]= None) -> None: top_p: Optional[int] = None,
super().__init__(frequency_penalty, ) -> None:
super().__init__(
frequency_penalty,
function_call, function_call,
functions, functions,
logit_bias, logit_bias,
@ -72,10 +91,11 @@ class AzureOpenAIConfig(OpenAIConfig):
presence_penalty, presence_penalty,
stop, stop,
temperature, temperature,
top_p) top_p,
)
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -89,7 +109,8 @@ class AzureChatCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {azure_ad_token}" headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers return headers
def completion(self, def completion(
self,
model: str, model: str,
messages: list, messages: list,
model_response: ModelResponse, model_response: ModelResponse,
@ -105,15 +126,16 @@ class AzureChatCompletion(BaseLLM):
litellm_params, litellm_params,
logger_fn, logger_fn,
acompletion: bool = False, acompletion: bool = False,
headers: Optional[dict]=None, headers: Optional[dict] = None,
client = None, client=None,
): ):
super().completion() super().completion()
exception_mapping_worked = False exception_mapping_worked = False
try: try:
if model is None or messages is None: if model is None or messages is None:
raise AzureOpenAIError(status_code=422, message=f"Missing model or messages") raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", 2)
@ -131,7 +153,7 @@ class AzureChatCompletion(BaseLLM):
"base_url": f"{api_base}", "base_url": f"{api_base}",
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -143,25 +165,52 @@ class AzureChatCompletion(BaseLLM):
else: else:
client = AzureOpenAI(**azure_client_params) client = AzureOpenAI(**azure_client_params)
data = { data = {"model": None, "messages": messages, **optional_params}
"model": None,
"messages": messages,
**optional_params
}
else: else:
data = { data = {
"model": model, # type: ignore "model": model, # type: ignore
"messages": messages, "messages": messages,
**optional_params **optional_params,
} }
if acompletion is True: if acompletion is True:
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else: else:
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client, logging_obj=logging_obj) return self.acompletion(
api_base=api_base,
data=data,
model_response=model_response,
api_key=api_key,
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True: elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client) return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
model=model,
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
client=client,
)
else: else:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -170,7 +219,7 @@ class AzureChatCompletion(BaseLLM):
additional_args={ additional_args={
"headers": { "headers": {
"api_key": api_key, "api_key": api_key,
"azure_ad_token": azure_ad_token "azure_ad_token": azure_ad_token,
}, },
"api_version": api_version, "api_version": api_version,
"api_base": api_base, "api_base": api_base,
@ -178,7 +227,9 @@ class AzureChatCompletion(BaseLLM):
}, },
) )
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@ -186,7 +237,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -209,14 +260,18 @@ class AzureChatCompletion(BaseLLM):
"api_base": api_base, "api_base": api_base,
}, },
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
except Exception as e: except Exception as e:
raise e raise e
async def acompletion(self, async def acompletion(
self,
api_key: str, api_key: str,
api_version: str, api_version: str,
model: str, model: str,
@ -224,15 +279,17 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
timeout: Any, timeout: Any,
model_response: ModelResponse, model_response: ModelResponse,
azure_ad_token: Optional[str]=None, azure_ad_token: Optional[str] = None,
client = None, # this is the AsyncAzureOpenAI client=None, # this is the AsyncAzureOpenAI
logging_obj=None, logging_obj=None,
): ):
response = None response = None
try: try:
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@ -240,7 +297,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -252,12 +309,20 @@ class AzureChatCompletion(BaseLLM):
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await azure_client.chat.completions.create(**data) response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) return convert_to_model_response_object(
response_object=json.loads(response.model_dump_json()),
model_response_object=model_response,
)
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
raise e raise e
@ -267,7 +332,8 @@ class AzureChatCompletion(BaseLLM):
else: else:
raise AzureOpenAIError(status_code=500, message=str(e)) raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(self, def streaming(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
api_key: str, api_key: str,
@ -275,12 +341,14 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
azure_ad_token: Optional[str]=None, azure_ad_token: Optional[str] = None,
client=None, client=None,
): ):
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
"api_version": api_version, "api_version": api_version,
@ -288,7 +356,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -300,15 +368,26 @@ class AzureChatCompletion(BaseLLM):
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = azure_client.chat.completions.create(**data) response = azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streamwrapper return streamwrapper
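
Each code path above assembles an azure_client_params dict and only adds api_key or azure_ad_token when one is present. A minimal sketch with the openai SDK's AzureOpenAI client; the endpoint, deployment, and key are placeholders, and http_client is omitted:

from openai import AzureOpenAI

azure_client_params = {
    "api_version": "2023-12-01-preview",                      # placeholder version
    "azure_endpoint": "https://my-resource.openai.azure.com",  # placeholder endpoint
    "azure_deployment": "gpt-35-turbo",                        # placeholder deployment
    "max_retries": 2,
    "timeout": 600,
}

api_key = "..."        # placeholder
azure_ad_token = None  # set instead of api_key when using AAD auth

if api_key is not None:
    azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
    azure_client_params["azure_ad_token"] = azure_ad_token

client = AzureOpenAI(**azure_client_params)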
async def async_streaming(self, async def async_streaming(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
api_key: str, api_key: str,
@ -316,8 +395,8 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
azure_ad_token: Optional[str]=None, azure_ad_token: Optional[str] = None,
client = None, client=None,
): ):
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
@ -326,7 +405,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": data.pop("max_retries", 2), "max_retries": data.pop("max_retries", 2),
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -338,12 +417,22 @@ class AzureChatCompletion(BaseLLM):
azure_client = client azure_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await azure_client.chat.completions.create(**data) response = await azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
@ -355,7 +444,7 @@ class AzureChatCompletion(BaseLLM):
api_key: str, api_key: str,
input: list, input: list,
client=None, client=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
@ -372,7 +461,11 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=stringified_response, original_response=stringified_response,
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="embedding",
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -383,7 +476,8 @@ class AzureChatCompletion(BaseLLM):
) )
raise e raise e
def embedding(self, def embedding(
self,
model: str, model: str,
input: list, input: list,
api_key: str, api_key: str,
@ -393,8 +487,8 @@ class AzureChatCompletion(BaseLLM):
logging_obj=None, logging_obj=None,
model_response=None, model_response=None,
optional_params=None, optional_params=None,
azure_ad_token: Optional[str]=None, azure_ad_token: Optional[str] = None,
client = None, client=None,
aembedding=None, aembedding=None,
): ):
super().embedding() super().embedding()
@ -402,14 +496,12 @@ class AzureChatCompletion(BaseLLM):
if self._client_session is None: if self._client_session is None:
self._client_session = self.create_client_session() self._client_session = self.create_client_session()
try: try:
data = { data = {"model": model, "input": input, **optional_params}
"model": model,
"input": input,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
@ -418,7 +510,7 @@ class AzureChatCompletion(BaseLLM):
"azure_deployment": model, "azure_deployment": model,
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -431,15 +523,19 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key, api_key=api_key,
additional_args={ additional_args={
"complete_input_dict": data, "complete_input_dict": data,
"headers": { "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
"api_key": api_key,
"azure_ad_token": azure_ad_token
}
}, },
) )
if aembedding == True: if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params) response = self.aembedding(
data=data,
input=input,
logging_obj=logging_obj,
api_key=api_key,
model_response=model_response,
azure_client_params=azure_client_params,
)
return response return response
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -455,7 +551,6 @@ class AzureChatCompletion(BaseLLM):
original_response=response, original_response=response,
) )
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding") # type: ignore
except AzureOpenAIError as e: except AzureOpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
@ -465,6 +560,7 @@ class AzureChatCompletion(BaseLLM):
raise e raise e
else: else:
import traceback import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation( async def aimage_generation(
@ -475,13 +571,17 @@ class AzureChatCompletion(BaseLLM):
api_key: str, api_key: str,
input: list, input: list,
client=None, client=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
client_session = litellm.aclient_session or httpx.AsyncClient(transport=AsyncCustomHTTPTransport(),) client_session = litellm.aclient_session or httpx.AsyncClient(
openai_aclient = AsyncAzureOpenAI(http_client=client_session, **azure_client_params) transport=AsyncCustomHTTPTransport(),
)
openai_aclient = AsyncAzureOpenAI(
http_client=client_session, **azure_client_params
)
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.images.generate(**data) response = await openai_aclient.images.generate(**data)
@ -493,7 +593,11 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=stringified_response, original_response=stringified_response,
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="image_generation") return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
response_type="image_generation",
)
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -504,15 +608,16 @@ class AzureChatCompletion(BaseLLM):
) )
raise e raise e
def image_generation(self, def image_generation(
self,
prompt: str, prompt: str,
timeout: float, timeout: float,
model: Optional[str]=None, model: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None, model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str]=None, azure_ad_token: Optional[str] = None,
logging_obj=None, logging_obj=None,
optional_params=None, optional_params=None,
client=None, client=None,
@ -524,14 +629,12 @@ class AzureChatCompletion(BaseLLM):
model = model model = model
else: else:
model = None model = None
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise AzureOpenAIError(status_code=422, message="max retries must be an int") raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params = {
@ -539,7 +642,7 @@ class AzureChatCompletion(BaseLLM):
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout "timeout": timeout,
} }
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
@ -551,7 +654,9 @@ class AzureChatCompletion(BaseLLM):
return response return response
if client is None: if client is None:
client_session = litellm.client_session or httpx.Client(transport=CustomHTTPTransport(),) client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(),
)
azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore azure_client = AzureOpenAI(http_client=client_session, **azure_client_params) # type: ignore
else: else:
azure_client = client azure_client = client
@ -560,7 +665,12 @@ class AzureChatCompletion(BaseLLM):
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=azure_client.api_key, api_key=azure_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"acompletion": False,
"complete_input_dict": data,
},
) )
## COMPLETION CALL ## COMPLETION CALL
@ -582,4 +692,5 @@ class AzureChatCompletion(BaseLLM):
raise e raise e
else: else:
import traceback import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc()) raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
View file
@ -3,8 +3,10 @@ import litellm
import httpx, certifi, ssl import httpx, certifi, ssl
from typing import Optional from typing import Optional
class BaseLLM: class BaseLLM:
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session
@ -22,26 +24,22 @@ class BaseLLM:
return _aclient_session return _aclient_session
def __exit__(self): def __exit__(self):
if hasattr(self, '_client_session'): if hasattr(self, "_client_session"):
self._client_session.close() self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_aclient_session'): if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose() await self._aclient_session.aclose()
def validate_environment(self): # set up the environment required to run the model def validate_environment(self): # set up the environment required to run the model
pass pass
def completion( def completion(
self, self, *args, **kwargs
*args,
**kwargs
): # logic for parsing in - calling - parsing out model completion calls ): # logic for parsing in - calling - parsing out model completion calls
pass pass
def embedding( def embedding(
self, self, *args, **kwargs
*args,
**kwargs
): # logic for parsing in - calling - parsing out model embedding calls ): # logic for parsing in - calling - parsing out model embedding calls
pass pass
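
base.py keeps one httpx.Client per LLM class and closes it in __exit__/__aexit__. A reduced sketch of that lifecycle without the litellm globals (the shared-session fallback is simplified away):

import httpx
from typing import Optional


class MiniBaseLLM:
    _client_session: Optional[httpx.Client] = None

    def create_client_session(self) -> httpx.Client:
        # lazily create and cache one client per class
        if self._client_session is None:
            self._client_session = httpx.Client()
        return self._client_session

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        if self._client_session is not None:
            self._client_session.close()


llm = MiniBaseLLM()
session = llm.create_client_session()
print(session.is_closed)  # False
llm.__exit__()
print(session.is_closed)  # True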
View file
@ -6,6 +6,7 @@ import time
from typing import Callable from typing import Callable
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
class BasetenError(Exception): class BasetenError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -14,6 +15,7 @@ class BasetenError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"accept": "application/json", "accept": "application/json",
@ -23,6 +25,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Api-Key {api_key}" headers["Authorization"] = f"Api-Key {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -52,7 +55,9 @@ def completion(
"inputs": prompt, "inputs": prompt,
"prompt": prompt, "prompt": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
## LOGGING ## LOGGING
@ -66,9 +71,13 @@ def completion(
completion_url_fragment_1 + model + completion_url_fragment_2, completion_url_fragment_1 + model + completion_url_fragment_2,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=True if "stream" in optional_params and optional_params["stream"] == True else False stream=True
if "stream" in optional_params and optional_params["stream"] == True
else False,
) )
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True): if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
):
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -91,9 +100,7 @@ def completion(
if ( if (
isinstance(completion_response["model_output"], dict) isinstance(completion_response["model_output"], dict)
and "data" in completion_response["model_output"] and "data" in completion_response["model_output"]
and isinstance( and isinstance(completion_response["model_output"]["data"], list)
completion_response["model_output"]["data"], list
)
): ):
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
"content" "content"
@ -112,12 +119,19 @@ def completion(
if "generated_text" not in completion_response: if "generated_text" not in completion_response:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code status_code=response.status_code,
) )
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] model_response["choices"][0]["message"][
"content"
] = completion_response[0]["generated_text"]
## GETTING LOGPROBS ## GETTING LOGPROBS
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: if (
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] "details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
@ -125,7 +139,7 @@ def completion(
else: else:
raise BasetenError( raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}", message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code status_code=response.status_code,
) )
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
@ -139,11 +153,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
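
The Baseten path above sets a stream flag in both the JSON payload and the requests call, then switches on the response's Content-Type. A compact sketch of just that branching; the URL and key are placeholders and only the requests semantics are kept:

import json
import requests


def call_model(url: str, api_key: str, prompt: str, optional_params: dict):
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": f"Api-Key {api_key}",
    }
    wants_stream = optional_params.get("stream", False) is True
    data = {
        "inputs": prompt,
        "prompt": prompt,
        "parameters": optional_params,
        "stream": wants_stream,
    }

    response = requests.post(url, headers=headers, data=json.dumps(data), stream=wants_stream)

    if "text/event-stream" in response.headers.get("Content-Type", "") or wants_stream:
        return response.iter_lines()  # caller consumes the event-stream lines
    return response.json()            # plain JSON body

# call_model("https://app.baseten.co/models/<id>/predict", "<api-key>", "hi", {"stream": False})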
View file
@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, get_secret, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx import httpx
class BedrockError(Exception): class BedrockError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock") self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class AmazonTitanConfig():
class AmazonTitanConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@ -29,29 +33,44 @@ class AmazonTitanConfig():
- `temperature` (float) temperature for model, - `temperature` (float) temperature for model,
- `topP` (int) top p for model - `topP` (int) top p for model
""" """
maxTokenCount: Optional[int]=None
stopSequences: Optional[list]=None
temperature: Optional[float]=None
topP: Optional[int]=None
def __init__(self, maxTokenCount: Optional[int] = None
maxTokenCount: Optional[int]=None, stopSequences: Optional[list] = None
stopSequences: Optional[list]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[int] = None
topP: Optional[int]=None) -> None:
def __init__(
self,
maxTokenCount: Optional[int] = None,
stopSequences: Optional[list] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAnthropicConfig():
class AmazonAnthropicConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@ -64,33 +83,48 @@ class AmazonAnthropicConfig():
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"], - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
""" """
max_tokens_to_sample: Optional[int]=litellm.max_tokens
stop_sequences: Optional[list]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[int]=None
anthropic_version: Optional[str]=None
def __init__(self, max_tokens_to_sample: Optional[int] = litellm.max_tokens
max_tokens_to_sample: Optional[int]=None, stop_sequences: Optional[list] = None
stop_sequences: Optional[list]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, top_k: Optional[int] = None
top_k: Optional[int]=None, top_p: Optional[int] = None
top_p: Optional[int]=None, anthropic_version: Optional[str] = None
anthropic_version: Optional[str]=None) -> None:
def __init__(
self,
max_tokens_to_sample: Optional[int] = None,
stop_sequences: Optional[list] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonCohereConfig():
class AmazonCohereConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@ -100,27 +134,42 @@ class AmazonCohereConfig():
- `temperature` (float) model temperature, - `temperature` (float) model temperature,
- `return_likelihood` (string) n/a - `return_likelihood` (string) n/a
""" """
max_tokens: Optional[int]=None
temperature: Optional[float]=None
return_likelihood: Optional[str]=None
def __init__(self, max_tokens: Optional[int] = None
max_tokens: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, return_likelihood: Optional[str] = None
return_likelihood: Optional[str]=None) -> None:
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAI21Config():
class AmazonAI21Config:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@ -140,39 +189,55 @@ class AmazonAI21Config():
- `countPenalty` (object): Placeholder for count penalty object. - `countPenalty` (object): Placeholder for count penalty object.
""" """
maxTokens: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
stopSequences: Optional[list]=None
frequencePenalty: Optional[dict]=None
presencePenalty: Optional[dict]=None
countPenalty: Optional[dict]=None
def __init__(self, maxTokens: Optional[int] = None
maxTokens: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[float] = None
topP: Optional[float]=None, stopSequences: Optional[list] = None
stopSequences: Optional[list]=None, frequencePenalty: Optional[dict] = None
frequencePenalty: Optional[dict]=None, presencePenalty: Optional[dict] = None
presencePenalty: Optional[dict]=None, countPenalty: Optional[dict] = None
countPenalty: Optional[dict]=None) -> None:
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
class AmazonLlamaConfig():
class AmazonLlamaConfig:
""" """
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1 Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
@ -182,48 +247,72 @@ class AmazonLlamaConfig():
- `temperature` (float) temperature for model, - `temperature` (float) temperature for model,
- `top_p` (float) top p for model - `top_p` (float) top p for model
""" """
max_gen_len: Optional[int]=None
temperature: Optional[float]=None
topP: Optional[float]=None
def __init__(self, max_gen_len: Optional[int] = None
maxTokenCount: Optional[int]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, topP: Optional[float] = None
topP: Optional[int]=None) -> None:
def __init__(
self,
maxTokenCount: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
):
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
    ]

    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)

    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
    ) = params_to_check

    if region_name:
        pass
    elif aws_region_name:
@@ -233,7 +322,10 @@ def init_bedrock_client(
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )
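# --- Editor's illustrative sketch (not part of this commit's diff) ---
# Any credential passed as an "os.environ/<VAR>" string is resolved through
# get_secret(), so proxy configs can reference environment variables instead of
# embedding raw keys. A minimal sketch (requires boto3; the region value is just
# an example):
bedrock_client = init_bedrock_client(
    aws_access_key_id="os.environ/AWS_ACCESS_KEY_ID",
    aws_secret_access_key="os.environ/AWS_SECRET_ACCESS_KEY",
    aws_region_name="us-west-2",
)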
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@ -242,9 +334,10 @@ def init_bedrock_client(
elif env_aws_bedrock_runtime_endpoint: elif env_aws_bedrock_runtime_endpoint:
endpoint_url = env_aws_bedrock_runtime_endpoint endpoint_url = env_aws_bedrock_runtime_endpoint
else: else:
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com' endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
import boto3 import boto3
if aws_access_key_id != None: if aws_access_key_id != None:
# uses auth params passed to completion # uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@@ -279,22 +372,20 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )
    else:
        prompt = ""
        for message in messages:
            if "role" in message:
                if message["role"] == "user":
                    prompt += f"{message['content']}"
                else:
                    prompt += f"{message['content']}"
            else:
                prompt += f"{message['content']}"
    return prompt
@ -309,6 +400,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name> # set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -327,7 +419,9 @@ def completion(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop( client = optional_params.pop(
@ -343,67 +437,71 @@ def completion(
        model = model
        provider = model.split(".")[0]
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
        inference_params = copy.deepcopy(optional_params)
        stream = inference_params.pop("stream", False)
        if provider == "anthropic":
            ## LOAD CONFIG
            config = litellm.AmazonAnthropicConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "cohere":
            ## LOAD CONFIG
            config = litellm.AmazonCohereConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            if optional_params.get("stream", False) == True:
                inference_params[
                    "stream"
                ] = True  # cohere requires stream = True in inference params
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "meta":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "amazon":  # amazon titan
            ## LOAD CONFIG
            config = litellm.AmazonTitanConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps(
                {
                    "inputText": prompt,
                    "textGenerationConfig": inference_params,
                }
            )
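# --- Editor's illustrative note (not part of this commit's diff) ---
# The provider branches above only differ in the JSON body handed to invoke_model().
# Placeholder examples of the two shapes:
anthropic_style_body = {"prompt": "\n\nHuman: hi\n\nAssistant: ", "temperature": 0.5}
# ai21 / cohere / meta use the same {"prompt": ..., **inference_params} shape;
# cohere additionally forces "stream": True when streaming was requested.
titan_body = {"inputText": "hi", "textGenerationConfig": {"temperature": 0.5}}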
## COMPLETION CALL ## COMPLETION CALL
accept = 'application/json' accept = "application/json"
contentType = 'application/json' contentType = "application/json"
if stream == True: if stream == True:
if provider == "ai21": if provider == "ai21":
## LOGGING ## LOGGING
@ -418,17 +516,17 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
) )
response = client.invoke_model( response = client.invoke_model(
body=data, body=data, modelId=model, accept=accept, contentType=contentType
modelId=model,
accept=accept,
contentType=contentType
) )
response = response.get('body').read() response = response.get("body").read()
return response return response
else: else:
## LOGGING ## LOGGING
@ -443,16 +541,16 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
) )
response = client.invoke_model_with_response_stream( response = client.invoke_model_with_response_stream(
body=data, body=data, modelId=model, accept=accept, contentType=contentType
modelId=model,
accept=accept,
contentType=contentType
) )
response = response.get('body') response = response.get("body")
return response return response
try: try:
## LOGGING ## LOGGING
@ -467,18 +565,18 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
) )
response = client.invoke_model( response = client.invoke_model(
body=data, body=data, modelId=model, accept=accept, contentType=contentType
modelId=model,
accept=accept,
contentType=contentType
) )
except Exception as e: except Exception as e:
raise BedrockError(status_code=500, message=str(e)) raise BedrockError(status_code=500, message=str(e))
response_body = json.loads(response.get('body').read()) response_body = json.loads(response.get("body").read())
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -491,16 +589,16 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
outputText = "default" outputText = "default"
if provider == "ai21": if provider == "ai21":
outputText = response_body.get('completions')[0].get('data').get('text') outputText = response_body.get("completions")[0].get("data").get("text")
elif provider == "anthropic": elif provider == "anthropic":
outputText = response_body['completion'] outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"] model_response["finish_reason"] = response_body["stop_reason"]
elif provider == "cohere": elif provider == "cohere":
outputText = response_body["generations"][0]["text"] outputText = response_body["generations"][0]["text"]
elif provider == "meta": elif provider == "meta":
outputText = response_body["generation"] outputText = response_body["generation"]
else: # amazon titan else: # amazon titan
outputText = response_body.get('results')[0].get('outputText') outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {}) response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400: if response_metadata.get("HTTPStatusCode", 500) >= 400:
@ -513,12 +611,13 @@ def completion(
if len(outputText) > 0: if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText model_response["choices"][0]["message"]["content"] = outputText
except: except:
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500)) raise BedrockError(
message=json.dumps(outputText),
status_code=response_metadata.get("HTTPStatusCode", 500),
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -528,7 +627,7 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens = prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
@ -540,8 +639,10 @@ def completion(
raise e raise e
else: else:
import traceback import traceback
raise BedrockError(status_code=500, message=traceback.format_exc()) raise BedrockError(status_code=500, message=traceback.format_exc())
def _embedding_func_single( def _embedding_func_single(
model: str, model: str,
input: str, input: str,
@ -554,13 +655,17 @@ def _embedding_func_single(
## FORMAT EMBEDDING INPUT ## ## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0] provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("user", None) # make sure user is not passed in for bedrock call inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
if provider == "amazon": if provider == "amazon":
input = input.replace(os.linesep, " ") input = input.replace(os.linesep, " ")
data = {"inputText": input, **inference_params} data = {"inputText": input, **inference_params}
# data = json.dumps(data) # data = json.dumps(data)
elif provider == "cohere": elif provider == "cohere":
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 inference_params["input_type"] = inference_params.get(
"input_type", "search_document"
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8") body = json.dumps(data).encode("utf-8")
## LOGGING ## LOGGING
@ -574,8 +679,10 @@ def _embedding_func_single(
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key="", # boto3 is used for init. api_key="", # boto3 is used for init.
additional_args={"complete_input_dict": {"model": model, additional_args={
"texts": input}, "request_str": request_str}, "complete_input_dict": {"model": model, "texts": input},
"request_str": request_str,
},
) )
try: try:
response = client.invoke_model( response = client.invoke_model(
@ -600,7 +707,10 @@ def _embedding_func_single(
elif provider == "amazon": elif provider == "amazon":
return response_body.get("embedding") return response_body.get("embedding")
except Exception as e: except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500) raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
def embedding( def embedding(
model: str, model: str,
@ -616,7 +726,9 @@ def embedding(
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None) aws_region_name = optional_params.pop("aws_region_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop("aws_bedrock_runtime_endpoint", None) aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@ -627,8 +739,16 @@ def embedding(
) )
## Embedding Call ## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls embeddings = [
_embedding_func_single(
model,
i,
optional_params=optional_params,
client=client,
logging_obj=logging_obj,
)
for i in input
] # [TODO]: make these parallel calls
## Populate OpenAI compliant dictionary ## Populate OpenAI compliant dictionary
embedding_response = [] embedding_response = []
@ -647,12 +767,10 @@ def embedding(
input_str = "".join(input) input_str = "".join(input)
input_tokens+=len(encoding.encode(input_str)) input_tokens += len(encoding.encode(input_str))
usage = Usage( usage = Usage(
prompt_tokens=input_tokens, prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0
completion_tokens=0,
total_tokens=input_tokens + 0
) )
model_response.usage = usage model_response.usage = usage

View file

@@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx


class CohereError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.cohere.ai/v1/generate"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class CohereConfig:
    """
    Reference: https://docs.cohere.com/reference/generate

@@ -50,46 +54,60 @@ class CohereConfig():
    - `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
    """

    num_generations: Optional[int] = None
    max_tokens: Optional[int] = None
    truncate: Optional[str] = None
    temperature: Optional[int] = None
    preset: Optional[str] = None
    end_sequences: Optional[list] = None
    stop_sequences: Optional[list] = None
    k: Optional[int] = None
    p: Optional[int] = None
    frequency_penalty: Optional[int] = None
    presence_penalty: Optional[int] = None
    return_likelihoods: Optional[str] = None
    logit_bias: Optional[dict] = None

    def __init__(
        self,
        num_generations: Optional[int] = None,
        max_tokens: Optional[int] = None,
        truncate: Optional[str] = None,
        temperature: Optional[int] = None,
        preset: Optional[str] = None,
        end_sequences: Optional[list] = None,
        stop_sequences: Optional[list] = None,
        k: Optional[int] = None,
        p: Optional[int] = None,
        frequency_penalty: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        return_likelihoods: Optional[str] = None,
        logit_bias: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
@ -100,6 +118,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -119,9 +138,11 @@ def completion(
prompt = " ".join(message["content"] for message in messages) prompt = " ".join(message["content"] for message in messages)
## Load Config ## Load Config
config=litellm.CohereConfig.get_config() config = litellm.CohereConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -134,14 +155,21 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
## error handling for cohere calls ## error handling for cohere calls
if response.status_code!=200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
@ -170,16 +198,20 @@ def completion(
message_obj = Message(content=item["text"]) message_obj = Message(content=item["text"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -189,11 +221,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
@ -206,11 +239,7 @@ def embedding(
headers = validate_environment(api_key) headers = validate_environment(api_key)
embed_url = "https://api.cohere.ai/v1/embed" embed_url = "https://api.cohere.ai/v1/embed"
model = model model = model
data = { data = {"model": model, "texts": input, **optional_params}
"model": model,
"texts": input,
**optional_params
}
if "3" in model and "input_type" not in data: if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document" # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
@ -223,9 +252,7 @@ def embedding(
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(embed_url, headers=headers, data=json.dumps(data))
embed_url, headers=headers, data=json.dumps(data)
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
@ -244,30 +271,23 @@ def embedding(
'usage' 'usage'
} }
""" """
if response.status_code!=200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
embeddings = response.json()['embeddings'] embeddings = response.json()["embeddings"]
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = output_data model_response["data"] = output_data
model_response["model"] = model model_response["model"] = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = { model_response["usage"] = {
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
return model_response return model_response

View file

@ -1,9 +1,11 @@
import time, json, httpx, asyncio import time, json, httpx, asyncio
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
""" """
Async implementation of custom http transport Async implementation of custom http transport
""" """
async def handle_async_request(self, request: httpx.Request) -> httpx.Response: async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[ if "images/generations" in request.url.path and request.url.params[
"api-version" "api-version"
@ -14,7 +16,9 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"2023-09-01-preview", "2023-09-01-preview",
"2023-10-01-preview", "2023-10-01-preview",
]: ]:
request.url = request.url.copy_with(path="/openai/images/generations:submit") request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = await super().handle_async_request(request) response = await super().handle_async_request(request)
operation_location_url = response.headers["operation-location"] operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url) request.url = httpx.URL(operation_location_url)
@ -26,7 +30,12 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
start_time = time.time() start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs: if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response( return httpx.Response(
status_code=400, status_code=400,
headers=response.headers, headers=response.headers,
@ -56,12 +65,14 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
) )
return await super().handle_async_request(request) return await super().handle_async_request(request)
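# --- Editor's illustrative sketch (not part of this commit's diff) ---
# The transport above intercepts Azure image-generation calls for older api-versions
# and polls the operation-location URL until the job finishes. One way to wire it in
# (the `http_client` argument reflects the openai v1.x client and is an assumption,
# not something shown in this diff):
import httpx
from openai import AsyncAzureOpenAI

azure_client = AsyncAzureOpenAI(
    api_key="sk-...",  # placeholder
    api_version="2023-09-01-preview",
    azure_endpoint="https://my-resource.openai.azure.com",  # hypothetical endpoint
    http_client=httpx.AsyncClient(transport=AsyncCustomHTTPTransport()),
)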
class CustomHTTPTransport(httpx.HTTPTransport): class CustomHTTPTransport(httpx.HTTPTransport):
""" """
This class was written as a workaround to support dall-e-2 on openai > v1.x This class was written as a workaround to support dall-e-2 on openai > v1.x
Refer to this issue for more: https://github.com/openai/openai-python/issues/692 Refer to this issue for more: https://github.com/openai/openai-python/issues/692
""" """
def handle_request( def handle_request(
self, self,
request: httpx.Request, request: httpx.Request,
@ -75,7 +86,9 @@ class CustomHTTPTransport(httpx.HTTPTransport):
"2023-09-01-preview", "2023-09-01-preview",
"2023-10-01-preview", "2023-10-01-preview",
]: ]:
request.url = request.url.copy_with(path="/openai/images/generations:submit") request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
response = super().handle_request(request) response = super().handle_request(request)
operation_location_url = response.headers["operation-location"] operation_location_url = response.headers["operation-location"]
request.url = httpx.URL(operation_location_url) request.url = httpx.URL(operation_location_url)
@ -87,7 +100,12 @@ class CustomHTTPTransport(httpx.HTTPTransport):
start_time = time.time() start_time = time.time()
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs: if time.time() - start_time > timeout_secs:
timeout = {"error": {"code": "Timeout", "message": "Operation polling timed out."}} timeout = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
return httpx.Response( return httpx.Response(
status_code=400, status_code=400,
headers=response.headers, headers=response.headers,

View file

@@ -8,17 +8,22 @@ import litellm
import sys, httpx
from .prompt_templates.factory import prompt_factory, custom_prompt


class GeminiError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://developers.generativeai.google/api/python/google/generativeai/chat",
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class GeminiConfig:
    """
    Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig

@@ -37,33 +42,44 @@ class GeminiConfig():
    - `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
    """

    candidate_count: Optional[int] = None
    stop_sequences: Optional[list] = None
    max_output_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None

    def __init__(
        self,
        candidate_count: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
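# --- Editor's illustrative sketch (not part of this commit's diff) ---
# GeminiConfig values end up as keyword arguments to genai.types.GenerationConfig in
# the completion() call below. A minimal direct-SDK sketch, assuming
# google-generativeai is installed and the API key lives in GEMINI_API_KEY
# (the env var name and model id are examples):
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
gemini = genai.GenerativeModel("models/gemini-pro")
response = gemini.generate_content(
    contents="Say hi in one word",
    generation_config=genai.types.GenerationConfig(max_output_tokens=8, temperature=0.2),
)
print(response.text)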
def completion( def completion(
@ -83,10 +99,11 @@ def completion(
try: try:
import google.generativeai as genai import google.generativeai as genai
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
@ -94,21 +111,25 @@ def completion(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini") prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="gemini"
)
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.GeminiConfig.get_config() config = litellm.GeminiConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -117,8 +138,11 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
_model = genai.GenerativeModel(f'models/{model}') _model = genai.GenerativeModel(f"models/{model}")
response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params)) response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
except Exception as e: except Exception as e:
raise GeminiError( raise GeminiError(
message=str(e), message=str(e),
@ -142,17 +166,22 @@ def completion(
message_obj = Message(content=item.content.parts[0].text) message_obj = Message(content=item.content.parts[0].text)
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise GeminiError(message=traceback.format_exc(), status_code=response.status_code) raise GeminiError(
message=traceback.format_exc(), status_code=response.status_code
)
try: try:
completion_response = model_response["choices"][0]["message"].get("content") completion_response = model_response["choices"][0]["message"].get("content")
except: except:
raise GeminiError(status_code=400, message=f"No response received. Original response - {response}") raise GeminiError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_str = "" prompt_str = ""
@ -164,9 +193,7 @@ def completion(
if content["type"] == "text": if content["type"] == "text":
prompt_str += content["text"] prompt_str += content["text"]
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt_str))
encoding.encode(prompt_str)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -176,11 +203,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass

View file

@@ -11,32 +11,47 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt


class HuggingfaceError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://api-inference.huggingface.co/models"
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class HuggingfaceConfig:
    """
    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
    """

    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: Optional[bool] = True  # enables returning logprobs + best of
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    return_full_text: Optional[
        bool
    ] = False  # by default don't return the input as part of the output
    seed: Optional[int] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
@@ -46,7 +61,8 @@ class HuggingfaceConfig():
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
@@ -60,19 +76,31 @@ class HuggingfaceConfig():
        top_p: Optional[int] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
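# --- Editor's illustrative note (not part of this commit's diff) ---
# For "text-generation-inference" models, HuggingfaceConfig values become the
# "parameters" object of the request body built in completion() below, e.g.:
import json

tgi_payload = {
    "inputs": "<s>[INST] hi [/INST]",  # prompt produced by prompt_factory / custom_prompt
    "parameters": {"max_new_tokens": 64, "details": True, "return_full_text": False},
    "stream": False,
}
tgi_body = json.dumps(tgi_payload)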
def output_parser(generated_text: str): def output_parser(generated_text: str):
""" """
@ -88,8 +116,11 @@ def output_parser(generated_text: str):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text return generated_text
tgi_models_cache = None tgi_models_cache = None
conv_models_cache = None conv_models_cache = None
def read_tgi_conv_models(): def read_tgi_conv_models():
try: try:
global tgi_models_cache, conv_models_cache global tgi_models_cache, conv_models_cache
@ -101,9 +132,13 @@ def read_tgi_conv_models():
tgi_models = set() tgi_models = set()
script_directory = os.path.dirname(os.path.abspath(__file__)) script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the file path relative to the script's directory # Construct the file path relative to the script's directory
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt") file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_text_generation_models.txt",
)
with open(file_path, 'r') as file: with open(file_path, "r") as file:
for line in file: for line in file:
tgi_models.add(line.strip()) tgi_models.add(line.strip())
@ -111,9 +146,13 @@ def read_tgi_conv_models():
tgi_models_cache = tgi_models tgi_models_cache = tgi_models
# If not, read the file and populate the cache # If not, read the file and populate the cache
file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt") file_path = os.path.join(
script_directory,
"huggingface_llms_metadata",
"hf_conversational_models.txt",
)
conv_models = set() conv_models = set()
with open(file_path, 'r') as file: with open(file_path, "r") as file:
for line in file: for line in file:
conv_models.add(line.strip()) conv_models.add(line.strip())
# Cache the set for future use # Cache the set for future use
@ -136,6 +175,7 @@ def get_hf_task_for_model(model):
else: else:
return "text-generation-inference" # default to tgi return "text-generation-inference" # default to tgi
class Huggingface(BaseLLM): class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None _client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None _aclient_session: Optional[httpx.AsyncClient] = None
@ -148,65 +188,93 @@ class Huggingface(BaseLLM):
"content-type": "application/json", "content-type": "application/json",
} }
if api_key and headers is None: if api_key and headers is None:
default_headers["Authorization"] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens default_headers[
"Authorization"
] = f"Bearer {api_key}" # Huggingface Inference Endpoint default is to accept bearer tokens
headers = default_headers headers = default_headers
elif headers: elif headers:
headers=headers headers = headers
else: else:
headers = default_headers headers = default_headers
return headers return headers
    def convert_to_model_response_object(
        self,
        completion_response,
        model_response,
        task,
        optional_params,
        encoding,
        input_text,
        model,
    ):
if task == "conversational": if task == "conversational":
if len(completion_response["generated_text"]) > 0: # type: ignore if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
"content" "content"
] = completion_response["generated_text"] # type: ignore ] = completion_response[
"generated_text"
] # type: ignore
elif task == "text-generation-inference": elif task == "text-generation-inference":
if (not isinstance(completion_response, list) if (
not isinstance(completion_response, list)
or not isinstance(completion_response[0], dict) or not isinstance(completion_response[0], dict)
or "generated_text" not in completion_response[0]): or "generated_text" not in completion_response[0]
raise HuggingfaceError(status_code=422, message=f"response is not in expected format - {completion_response}") ):
raise HuggingfaceError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = output_parser(
"content" completion_response[0]["generated_text"]
] = output_parser(completion_response[0]["generated_text"]) )
## GETTING LOGPROBS + FINISH REASON ## GETTING LOGPROBS + FINISH REASON
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: if (
model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"] "details" in completion_response[0]
and "tokens" in completion_response[0]["details"]
):
model_response.choices[0].finish_reason = completion_response[0][
"details"
]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None: if token["logprob"] != None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0]["message"]._logprob = sum_logprob model_response["choices"][0]["message"]._logprob = sum_logprob
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
if "details" in completion_response[0] and "best_of_sequences" in completion_response[0]["details"]: if (
"details" in completion_response[0]
and "best_of_sequences" in completion_response[0]["details"]
):
choices_list = [] choices_list = []
for idx, item in enumerate(completion_response[0]["details"]["best_of_sequences"]): for idx, item in enumerate(
completion_response[0]["details"]["best_of_sequences"]
):
sum_logprob = 0 sum_logprob = 0
for token in item["tokens"]: for token in item["tokens"]:
if token["logprob"] != None: if token["logprob"] != None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0: if len(item["generated_text"]) > 0:
message_obj = Message(content=output_parser(item["generated_text"]), logprobs=sum_logprob) message_obj = Message(
content=output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(finish_reason=item["finish_reason"], index=idx+1, message=message_obj) choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response["choices"].extend(choices_list)
else: else:
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = output_parser(
"content" completion_response[0]["generated_text"]
] = output_parser(completion_response[0]["generated_text"]) )
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = 0 prompt_tokens = 0
try: try:
@ -221,7 +289,9 @@ class Huggingface(BaseLLM):
completion_tokens = 0 completion_tokens = 0
try: try:
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here ) ##[TODO] use the llama2 tokenizer here
except: except:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
@ -234,13 +304,14 @@ class Huggingface(BaseLLM):
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
model_response._hidden_params["original_response"] = completion_response model_response._hidden_params["original_response"] = completion_response
return model_response return model_response
def completion(self, def completion(
self,
model: str, model: str,
messages: list, messages: list,
api_base: Optional[str], api_base: Optional[str],
@ -276,9 +347,11 @@ class Huggingface(BaseLLM):
completion_url = f"https://api-inference.huggingface.co/models/{model}" completion_url = f"https://api-inference.huggingface.co/models/{model}"
## Load Config ## Load Config
config=litellm.HuggingfaceConfig.get_config() config = litellm.HuggingfaceConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
### MAP INPUT PARAMS ### MAP INPUT PARAMS
@ -300,9 +373,9 @@ class Huggingface(BaseLLM):
"inputs": { "inputs": {
"text": text, "text": text,
"past_user_inputs": past_user_inputs, "past_user_inputs": past_user_inputs,
"generated_responses": generated_responses "generated_responses": generated_responses,
}, },
"parameters": inference_params "parameters": inference_params,
} }
input_text = "".join(message["content"] for message in messages) input_text = "".join(message["content"] for message in messages)
elif task == "text-generation-inference": elif task == "text-generation-inference":
@ -312,16 +385,22 @@ class Huggingface(BaseLLM):
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", None), role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get(
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), "initial_prompt_value", ""
messages=messages ),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False, "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
input_text = prompt input_text = prompt
else: else:
@ -332,8 +411,12 @@ class Huggingface(BaseLLM):
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( prompt = custom_prompt(
role_dict=model_prompt_details.get("roles", {}), role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get(
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), "initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get(
"final_prompt_value", ""
),
bos_token=model_prompt_details.get("bos_token", ""), bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""), eos_token=model_prompt_details.get("eos_token", ""),
messages=messages, messages=messages,
@ -346,14 +429,22 @@ class Huggingface(BaseLLM):
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": inference_params, "parameters": inference_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False, "stream": True
if "stream" in optional_params and optional_params["stream"] == True
else False,
} }
input_text = prompt input_text = prompt
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input_text, input=input_text,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "task": task, "headers": headers, "api_base": completion_url, "acompletion": acompletion}, additional_args={
"complete_input_dict": data,
"task": task,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
) )
## COMPLETION CALL ## COMPLETION CALL
if acompletion is True: if acompletion is True:
@ -369,29 +460,37 @@ class Huggingface(BaseLLM):
completion_url, completion_url,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] stream=optional_params["stream"],
) )
return response.iter_lines() return response.iter_lines()
### SYNC COMPLETION ### SYNC COMPLETION
else: else:
response = requests.post( response = requests.post(
completion_url, completion_url, headers=headers, data=json.dumps(data)
headers=headers,
data=json.dumps(data)
) )
## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten) ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
is_streamed = False is_streamed = False
if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream": if (
response.__dict__["headers"].get("Content-Type", "")
== "text/event-stream"
):
is_streamed = True is_streamed = True
# iterate over the complete streamed response, and return the final answer # iterate over the complete streamed response, and return the final answer
if is_streamed: if is_streamed:
streamed_response = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="huggingface", logging_obj=logging_obj) streamed_response = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
content = "" content = ""
for chunk in streamed_response: for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"] content += chunk["choices"][0]["delta"]["content"]
completion_response: List[Dict[str, Any]] = [{"generated_text": content}] completion_response: List[Dict[str, Any]] = [
{"generated_text": content}
]
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input_text, input=input_text,
@ -414,11 +513,16 @@ class Huggingface(BaseLLM):
completion_response = [completion_response] completion_response = [completion_response]
except: except:
import traceback import traceback
raise HuggingfaceError( raise HuggingfaceError(
message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}", status_code=response.status_code message=f"Original Response received: {response.text}; Stacktrace: {traceback.format_exc()}",
status_code=response.status_code,
) )
print_verbose(f"response: {completion_response}") print_verbose(f"response: {completion_response}")
if isinstance(completion_response, dict) and "error" in completion_response: if (
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}") print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"response.status_code: {response.status_code}") print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError( raise HuggingfaceError(
@ -432,7 +536,7 @@ class Huggingface(BaseLLM):
optional_params=optional_params, optional_params=optional_params,
encoding=encoding, encoding=encoding,
input_text=input_text, input_text=input_text,
model=model model=model,
) )
except HuggingfaceError as e: except HuggingfaceError as e:
exception_mapping_worked = True exception_mapping_worked = True
@ -442,9 +546,11 @@ class Huggingface(BaseLLM):
raise e raise e
else: else:
import traceback import traceback
raise HuggingfaceError(status_code=500, message=traceback.format_exc()) raise HuggingfaceError(status_code=500, message=traceback.format_exc())
async def acompletion(self, async def acompletion(
self,
api_base: str, api_base: str,
data: dict, data: dict,
headers: dict, headers: dict,
@ -453,54 +559,75 @@ class Huggingface(BaseLLM):
encoding: Any, encoding: Any,
input_text: str, input_text: str,
model: str, model: str,
optional_params: dict): optional_params: dict,
):
response = None response = None
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers, timeout=None) response = await client.post(
url=api_base, json=data, headers=headers, timeout=None
)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response) raise HuggingfaceError(
status_code=response.status_code,
message=response.text,
request=response.request,
response=response,
)
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json, return self.convert_to_model_response_object(
completion_response=response_json,
model_response=model_response, model_response=model_response,
task=task, task=task,
encoding=encoding, encoding=encoding,
input_text=input_text, input_text=input_text,
model=model, model=model,
optional_params=optional_params) optional_params=optional_params,
)
except Exception as e: except Exception as e:
if isinstance(e,httpx.TimeoutException): if isinstance(e, httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error") raise HuggingfaceError(status_code=500, message="Request Timeout Error")
elif response is not None and hasattr(response, "text"): elif response is not None and hasattr(response, "text"):
raise HuggingfaceError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}") raise HuggingfaceError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else: else:
raise HuggingfaceError(status_code=500, message=f"{str(e)}") raise HuggingfaceError(status_code=500, message=f"{str(e)}")
async def async_streaming(self, async def async_streaming(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
data: dict, data: dict,
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str): model: str,
):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = client.stream( response = client.stream(
"POST", "POST", url=f"{api_base}", json=data, headers=headers
url=f"{api_base}",
json=data,
headers=headers
) )
async with response as r: async with response as r:
if r.status_code != 200: if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming") raise HuggingfaceError(
status_code=r.status_code,
message="An error occurred while streaming",
)
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
def embedding(self, def embedding(
self,
model: str, model: str,
input: list, input: list,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -526,29 +653,35 @@ class Huggingface(BaseLLM):
if "sentence-transformers" in model: if "sentence-transformers" in model:
if len(input) == 0: if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences") raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = { data = {
"inputs": { "inputs": {
"source_sentence": input[0], "source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ] "sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
} }
} }
else: else:
data = { data = {"inputs": input} # type: ignore
"inputs": input # type: ignore
}
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url}, additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(embed_url, headers=headers, data=json.dumps(data))
embed_url, headers=headers, data=json.dumps(data)
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -558,11 +691,10 @@ class Huggingface(BaseLLM):
original_response=response, original_response=response,
) )
embeddings = response.json() embeddings = response.json()
if "error" in embeddings: if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error']) raise HuggingfaceError(status_code=500, message=embeddings["error"])
output_data = [] output_data = []
if "similarities" in embeddings: if "similarities" in embeddings:
@ -571,7 +703,7 @@ class Huggingface(BaseLLM):
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
else: else:
@ -581,7 +713,7 @@ class Huggingface(BaseLLM):
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
elif isinstance(embedding, list) and isinstance(embedding[0], float): elif isinstance(embedding, list) and isinstance(embedding[0], float):
@ -589,7 +721,7 @@ class Huggingface(BaseLLM):
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding # flatten list returned from hf "embedding": embedding, # flatten list returned from hf
} }
) )
else: else:
@ -597,7 +729,9 @@ class Huggingface(BaseLLM):
{ {
"object": "embedding", "object": "embedding",
"index": idx, "index": idx,
"embedding": embedding[0][0] # flatten list returned from hf "embedding": embedding[0][
0
], # flatten list returned from hf
} }
) )
model_response["object"] = "list" model_response["object"] = "list"
@ -605,13 +739,10 @@ class Huggingface(BaseLLM):
model_response["model"] = model model_response["model"] = model
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = { model_response["usage"] = {
"prompt_tokens": input_tokens, "prompt_tokens": input_tokens,
"total_tokens": input_tokens, "total_tokens": input_tokens,
} }
return model_response return model_response
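The fallback above handles providers (e.g. Baseten) that stream even when stream was not requested: the handler detects the text/event-stream Content-Type and drains the stream into one final answer. A minimal standalone sketch of that idea using requests, with the SSE parsing simplified to a plain line join:

    import requests

    def drain_if_streamed(response: requests.Response) -> str:
        # Unexpected event stream: read it to the end and return one aggregated body.
        if response.headers.get("Content-Type", "") == "text/event-stream":
            return "".join(
                line.decode("utf-8") for line in response.iter_lines() if line
            )
        return response.text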


@ -7,6 +7,7 @@ from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
class MaritalkError(Exception): class MaritalkError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,7 +16,8 @@ class MaritalkError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class MaritTalkConfig():
class MaritTalkConfig:
""" """
The class `MaritTalkConfig` provides configuration for the MaritTalk API interface. Here are the parameters: The class `MaritTalkConfig` provides configuration for the MaritTalk API interface. Here are the parameters:
@ -33,6 +35,7 @@ class MaritTalkConfig():
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped. - `stopping_tokens` (list of string): List of tokens where the conversation can be stopped.
""" """
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
model: Optional[str] = None model: Optional[str] = None
do_sample: Optional[bool] = None do_sample: Optional[bool] = None
@ -41,26 +44,39 @@ class MaritTalkConfig():
repetition_penalty: Optional[float] = None repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None stopping_tokens: Optional[List[str]] = None
def __init__(self, def __init__(
max_tokens: Optional[int]=None, self,
max_tokens: Optional[int] = None,
model: Optional[str] = None, model: Optional[str] = None,
do_sample: Optional[bool] = None, do_sample: Optional[bool] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None, repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None) -> None: stopping_tokens: Optional[List[str]] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
@ -71,6 +87,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Key {api_key}" headers["Authorization"] = f"Key {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -89,9 +106,11 @@ def completion(
model = model model = model
## Load Config ## Load Config
config=litellm.MaritTalkConfig.get_config() config = litellm.MaritTalkConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -107,7 +126,10 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
@ -130,15 +152,17 @@ def completion(
else: else:
try: try:
if len(completion_response["answer"]) > 0: if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["answer"] model_response["choices"][0]["message"][
"content"
] = completion_response["answer"]
except Exception as e: except Exception as e:
raise MaritalkError(message=response.text, status_code=response.status_code) raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages) prompt = "".join(m["content"] for m in messages)
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -148,11 +172,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
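Because __init__ above writes every non-None argument onto the class with setattr, instantiating the config once sets provider-wide defaults that get_config() returns later. A small usage sketch (the values are illustrative):

    import litellm

    litellm.MaritTalkConfig(max_tokens=200, temperature=0.6)
    defaults = litellm.MaritTalkConfig.get_config()
    # defaults now includes {"max_tokens": 200, "temperature": 0.6}; completion()
    # merges these in for any key the caller did not pass explicitly.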


@ -7,6 +7,7 @@ from typing import Callable, Optional
import litellm import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
class NLPCloudError(Exception): class NLPCloudError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,7 +16,8 @@ class NLPCloudError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class NLPCloudConfig():
class NLPCloudConfig:
""" """
Reference: https://docs.nlpcloud.com/#generation Reference: https://docs.nlpcloud.com/#generation
@ -43,45 +45,57 @@ class NLPCloudConfig():
- `num_return_sequences` (int): Optional. The number of independently computed returned sequences. - `num_return_sequences` (int): Optional. The number of independently computed returned sequences.
""" """
max_length: Optional[int]=None
length_no_input: Optional[bool]=None
end_sequence: Optional[str]=None
remove_end_sequence: Optional[bool]=None
remove_input: Optional[bool]=None
bad_words: Optional[list]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
num_beams: Optional[int]=None
num_return_sequences: Optional[int]=None
max_length: Optional[int] = None
length_no_input: Optional[bool] = None
end_sequence: Optional[str] = None
remove_end_sequence: Optional[bool] = None
remove_input: Optional[bool] = None
bad_words: Optional[list] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
num_beams: Optional[int] = None
num_return_sequences: Optional[int] = None
def __init__(self, def __init__(
max_length: Optional[int]=None, self,
length_no_input: Optional[bool]=None, max_length: Optional[int] = None,
end_sequence: Optional[str]=None, length_no_input: Optional[bool] = None,
remove_end_sequence: Optional[bool]=None, end_sequence: Optional[str] = None,
remove_input: Optional[bool]=None, remove_end_sequence: Optional[bool] = None,
bad_words: Optional[list]=None, remove_input: Optional[bool] = None,
temperature: Optional[float]=None, bad_words: Optional[list] = None,
top_p: Optional[float]=None, temperature: Optional[float] = None,
top_k: Optional[int]=None, top_p: Optional[float] = None,
repetition_penalty: Optional[float]=None, top_k: Optional[int] = None,
num_beams: Optional[int]=None, repetition_penalty: Optional[float] = None,
num_return_sequences: Optional[int]=None) -> None: num_beams: Optional[int] = None,
num_return_sequences: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@ -93,6 +107,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}" headers["Authorization"] = f"Token {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -112,7 +127,9 @@ def completion(
## Load Config ## Load Config
config = litellm.NLPCloudConfig.get_config() config = litellm.NLPCloudConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
completion_url_fragment_1 = api_base completion_url_fragment_1 = api_base
@ -131,11 +148,18 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=text, input=text,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": completion_url}, additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return clean_and_iterate_chunks(response) return clean_and_iterate_chunks(response)
@ -161,9 +185,14 @@ def completion(
else: else:
try: try:
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["generated_text"] model_response["choices"][0]["message"][
"content"
] = completion_response["generated_text"]
except: except:
raise NLPCloudError(message=json.dumps(completion_response), status_code=response.status_code) raise NLPCloudError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = completion_response["nb_input_tokens"] prompt_tokens = completion_response["nb_input_tokens"]
@ -174,7 +203,7 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
@ -187,25 +216,27 @@ def completion(
# # Perform further processing based on your needs # # Perform further processing based on your needs
# return cleaned_chunk # return cleaned_chunk
# for line in response.iter_lines(): # for line in response.iter_lines():
# if line: # if line:
# yield process_chunk(line) # yield process_chunk(line)
def clean_and_iterate_chunks(response): def clean_and_iterate_chunks(response):
buffer = b'' buffer = b""
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):
if not chunk: if not chunk:
break break
buffer += chunk buffer += chunk
while b'\x00' in buffer: while b"\x00" in buffer:
buffer = buffer.replace(b'\x00', b'') buffer = buffer.replace(b"\x00", b"")
yield buffer.decode('utf-8') yield buffer.decode("utf-8")
buffer = b'' buffer = b""
# No more data expected, yield any remaining data in the buffer # No more data expected, yield any remaining data in the buffer
if buffer: if buffer:
yield buffer.decode('utf-8') yield buffer.decode("utf-8")
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
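clean_and_iterate_chunks above buffers streamed bytes, strips NUL bytes, and yields decoded text. The same logic as a standalone generator fed a plain list of byte chunks instead of response.iter_content (the yield is assumed to sit inside the while loop, as in the original source):

    def clean_and_iterate_chunks(byte_chunks):
        buffer = b""
        for chunk in byte_chunks:
            if not chunk:
                break
            buffer += chunk
            while b"\x00" in buffer:              # NUL bytes trigger a flush
                buffer = buffer.replace(b"\x00", b"")
                yield buffer.decode("utf-8")
                buffer = b""
        if buffer:                                # flush whatever is left at the end
            yield buffer.decode("utf-8")

    # list(clean_and_iterate_chunks([b'{"generated_text": "Hi"}\x00'])) -> ['{"generated_text": "Hi"}']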


@ -6,6 +6,7 @@ import litellm
import httpx, aiohttp, asyncio import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception): class OllamaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -16,7 +17,8 @@ class OllamaError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class OllamaConfig():
class OllamaConfig:
""" """
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
@ -56,51 +58,68 @@ class OllamaConfig():
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile) - `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
""" """
mirostat: Optional[int]=None
mirostat_eta: Optional[float]=None
mirostat_tau: Optional[float]=None
num_ctx: Optional[int]=None
num_gqa: Optional[int]=None
num_thread: Optional[int]=None
repeat_last_n: Optional[int]=None
repeat_penalty: Optional[float]=None
temperature: Optional[float]=None
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float]=None
num_predict: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
system: Optional[str]=None
template: Optional[str]=None
def __init__(self, mirostat: Optional[int] = None
mirostat: Optional[int]=None, mirostat_eta: Optional[float] = None
mirostat_eta: Optional[float]=None, mirostat_tau: Optional[float] = None
mirostat_tau: Optional[float]=None, num_ctx: Optional[int] = None
num_ctx: Optional[int]=None, num_gqa: Optional[int] = None
num_gqa: Optional[int]=None, num_thread: Optional[int] = None
num_thread: Optional[int]=None, repeat_last_n: Optional[int] = None
repeat_last_n: Optional[int]=None, repeat_penalty: Optional[float] = None
repeat_penalty: Optional[float]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, stop: Optional[
stop: Optional[list]=None, list
tfs_z: Optional[float]=None, ] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
num_predict: Optional[int]=None, tfs_z: Optional[float] = None
top_k: Optional[int]=None, num_predict: Optional[int] = None
top_p: Optional[float]=None, top_k: Optional[int] = None
system: Optional[str]=None, top_p: Optional[float] = None
template: Optional[str]=None) -> None: system: Optional[str] = None
template: Optional[str] = None
def __init__(
self,
mirostat: Optional[int] = None,
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
num_ctx: Optional[int] = None,
num_gqa: Optional[int] = None,
num_thread: Optional[int] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
@ -111,36 +130,51 @@ def get_ollama_response(
logging_obj=None, logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
model_response=None, model_response=None,
encoding=None encoding=None,
): ):
if api_base.endswith("/api/generate"): if api_base.endswith("/api/generate"):
url = api_base url = api_base
else: else:
url = f"{api_base}/api/generate" url = f"{api_base}/api/generate"
## Load Config ## Load Config
config=litellm.OllamaConfig.get_config() config = litellm.OllamaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False) optional_params["stream"] = optional_params.get("stream", False)
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=None, input=None,
api_key=None, api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,}, additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
) )
if acompletion is True: if acompletion is True:
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) == True:
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else: else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj) response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response return response
elif optional_params.get("stream", False) == True: elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj) return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
@ -168,7 +202,16 @@ def get_ollama_response(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json": if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
@ -176,44 +219,59 @@ def get_ollama_response(
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response return model_response
def ollama_completion_stream(url, data, logging_obj): def ollama_completion_stream(url, data, logging_obj):
with httpx.stream( with httpx.stream(
url=url, url=url, json=data, method="POST", timeout=litellm.request_timeout
json=data,
method="POST",
timeout=litellm.request_timeout
) as response: ) as response:
try: try:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
raise e raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try: try:
client = httpx.AsyncClient() client = httpx.AsyncClient()
async with client.stream( async with client.stream(
url=f"{url}", url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
json=data,
method="POST",
timeout=litellm.request_timeout
) as response: ) as response:
if response.status_code != 200: if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text) raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj) streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj): async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
data["stream"] = False data["stream"] = False
try: try:
@ -227,7 +285,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=data['prompt'], input=data["prompt"],
api_key="", api_key="",
original_response=resp.text, original_response=resp.text,
additional_args={ additional_args={
@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json": if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}]) message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": response_json["response"],
"name": "",
},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json[
"response"
]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model'] model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json["prompt_eval_count"] # type: ignore prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"] completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens) model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response return model_response
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise e raise e
async def ollama_aembeddings(api_base="http://localhost:11434", async def ollama_aembeddings(
api_base="http://localhost:11434",
model="llama2", model="llama2",
prompt="Why is the sky blue?", prompt="Why is the sky blue?",
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
model_response=None, model_response=None,
encoding=None): encoding=None,
):
if api_base.endswith("/api/embeddings"): if api_base.endswith("/api/embeddings"):
url = api_base url = api_base
else: else:
url = f"{api_base}/api/embeddings" url = f"{api_base}/api/embeddings"
## Load Config ## Load Config
config=litellm.OllamaConfig.get_config() config = litellm.OllamaConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { data = {
@ -308,11 +388,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
model_response["data"] = output_data model_response["data"] = output_data


@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class OobaboogaError(Exception): class OobaboogaError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -15,6 +16,7 @@ class OobaboogaError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def validate_environment(api_key): def validate_environment(api_key):
headers = { headers = {
"accept": "application/json", "accept": "application/json",
@ -24,6 +26,7 @@ def validate_environment(api_key):
headers["Authorization"] = f"Token {api_key}" headers["Authorization"] = f"Token {api_key}"
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -45,7 +48,10 @@ def completion(
elif api_base: elif api_base:
completion_url = api_base completion_url = api_base
else: else:
raise OobaboogaError(status_code=404, message="API Base not set. Set one via completion(..,api_base='your-api-url')") raise OobaboogaError(
status_code=404,
message="API Base not set. Set one via completion(..,api_base='your-api-url')",
)
model = model model = model
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
@ -54,7 +60,7 @@ def completion(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
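The branch above only takes the custom-prompt path when the model has an entry in custom_prompt_dict; otherwise it falls back to prompt_factory. A hypothetical entry showing the keys the code reads (the model name, role markers, and the nested pre_message/post_message layout are illustrative assumptions; the top-level keys match what the code reads):

    custom_prompt_dict = {
        "oobabooga/my-local-model": {                  # hypothetical model key
            "roles": {
                "system": {"pre_message": "### System:\n", "post_message": "\n"},
                "user": {"pre_message": "### User:\n", "post_message": "\n"},
                "assistant": {"pre_message": "### Assistant:\n", "post_message": "\n"},
            },
            "initial_prompt_value": "",
            "final_prompt_value": "### Assistant:\n",
        }
    }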
@ -72,7 +78,10 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
@ -89,7 +98,9 @@ def completion(
try: try:
completion_response = response.json() completion_response = response.json()
except: except:
raise OobaboogaError(message=response.text, status_code=response.status_code) raise OobaboogaError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response: if "error" in completion_response:
raise OobaboogaError( raise OobaboogaError(
message=completion_response["error"], message=completion_response["error"],
@ -97,14 +108,17 @@ def completion(
) )
else: else:
try: try:
model_response["choices"][0]["message"]["content"] = completion_response['results'][0]['text'] model_response["choices"][0]["message"][
"content"
] = completion_response["results"][0]["text"]
except: except:
raise OobaboogaError(message=json.dumps(completion_response), status_code=response.status_code) raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"]) encoding.encode(model_response["choices"][0]["message"]["content"])
) )
@ -114,11 +128,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@ -2,15 +2,29 @@ from typing import Optional, Union, Any
import types, time, json import types, time, json
import httpx import httpx
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
Usage,
)
from typing import Callable, Optional from typing import Callable, Optional
import aiohttp, requests import aiohttp, requests
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI
class OpenAIError(Exception): class OpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None): def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
if request: if request:
@ -20,13 +34,15 @@ class OpenAIError(Exception):
if response: if response:
self.response = response self.response = response
else: else:
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class OpenAIConfig(): class OpenAIConfig:
""" """
Reference: https://platform.openai.com/docs/api-reference/chat/create Reference: https://platform.openai.com/docs/api-reference/chat/create
@ -52,42 +68,56 @@ class OpenAIConfig():
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
""" """
frequency_penalty: Optional[int]=None
function_call: Optional[Union[str, dict]]=None
functions: Optional[list]=None
logit_bias: Optional[dict]=None
max_tokens: Optional[int]=None
n: Optional[int]=None
presence_penalty: Optional[int]=None
stop: Optional[Union[str, list]]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
def __init__(self, frequency_penalty: Optional[int] = None
frequency_penalty: Optional[int]=None, function_call: Optional[Union[str, dict]] = None
function_call: Optional[Union[str, dict]]=None, functions: Optional[list] = None
functions: Optional[list]=None, logit_bias: Optional[dict] = None
logit_bias: Optional[dict]=None, max_tokens: Optional[int] = None
max_tokens: Optional[int]=None, n: Optional[int] = None
n: Optional[int]=None, presence_penalty: Optional[int] = None
presence_penalty: Optional[int]=None, stop: Optional[Union[str, list]] = None
stop: Optional[Union[str, list]]=None, temperature: Optional[int] = None
temperature: Optional[int]=None, top_p: Optional[int] = None
top_p: Optional[int]=None,) -> None:
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class OpenAITextCompletionConfig():
class OpenAITextCompletionConfig:
""" """
Reference: https://platform.openai.com/docs/api-reference/completions/create Reference: https://platform.openai.com/docs/api-reference/completions/create
@ -117,65 +147,80 @@ class OpenAITextCompletionConfig():
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
""" """
best_of: Optional[int]=None
echo: Optional[bool]=None
frequency_penalty: Optional[int]=None
logit_bias: Optional[dict]=None
logprobs: Optional[int]=None
max_tokens: Optional[int]=None
n: Optional[int]=None
presence_penalty: Optional[int]=None
stop: Optional[Union[str, list]]=None
suffix: Optional[str]=None
temperature: Optional[float]=None
top_p: Optional[float]=None
def __init__(self, best_of: Optional[int] = None
best_of: Optional[int]=None, echo: Optional[bool] = None
echo: Optional[bool]=None, frequency_penalty: Optional[int] = None
frequency_penalty: Optional[int]=None, logit_bias: Optional[dict] = None
logit_bias: Optional[dict]=None, logprobs: Optional[int] = None
logprobs: Optional[int]=None, max_tokens: Optional[int] = None
max_tokens: Optional[int]=None, n: Optional[int] = None
n: Optional[int]=None, presence_penalty: Optional[int] = None
presence_penalty: Optional[int]=None, stop: Optional[Union[str, list]] = None
stop: Optional[Union[str, list]]=None, suffix: Optional[str] = None
suffix: Optional[str]=None, temperature: Optional[float] = None
temperature: Optional[float]=None, top_p: Optional[float] = None
top_p: Optional[float]=None) -> None:
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class OpenAIChatCompletion(BaseLLM): class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def completion(self, def completion(
self,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
model: Optional[str]=None, model: Optional[str] = None,
messages: Optional[list]=None, messages: Optional[list] = None,
print_verbose: Optional[Callable]=None, print_verbose: Optional[Callable] = None,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
acompletion: bool = False, acompletion: bool = False,
logging_obj=None, logging_obj=None,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers: Optional[dict]=None, headers: Optional[dict] = None,
custom_prompt_dict: dict={}, custom_prompt_dict: dict = {},
client=None client=None,
): ):
super().completion() super().completion()
exception_mapping_worked = False exception_mapping_worked = False
@ -186,36 +231,79 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=422, message=f"Missing model or messages") raise OpenAIError(status_code=422, message=f"Missing model or messages")
if not isinstance(timeout, float): if not isinstance(timeout, float):
raise OpenAIError(status_code=422, message=f"Timeout needs to be a float") raise OpenAIError(
status_code=422, message=f"Timeout needs to be a float"
)
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message for _ in range(
data = { 2
"model": model, ): # if call fails due to alternating messages, retry with reformatted message
"messages": messages, data = {"model": model, "messages": messages, **optional_params}
**optional_params
}
try: try:
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if acompletion is True: if acompletion is True:
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else: else:
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) return self.acompletion(
data=data,
headers=headers,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
elif optional_params.get("stream", False): elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) return self.streaming(
logging_obj=logging_obj,
headers=headers,
data=data,
model=model,
api_base=api_base,
api_key=api_key,
timeout=timeout,
client=client,
max_retries=max_retries,
)
else: else:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=messages, input=messages,
api_key=api_key, api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data}, additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": acompletion,
"complete_input_dict": data,
},
) )
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(
status_code=422, message="max retries must be an int"
)
if client is None: if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_client = client openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore response = openai_client.chat.completions.create(**data) # type: ignore
@ -226,16 +314,23 @@ class OpenAIChatCompletion(BaseLLM):
original_response=stringified_response, original_response=stringified_response,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except Exception as e: except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e): if "Conversation roles must alternate user/assistant" in str(
e
) or "user and assistant roles should be alternating" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' messages, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = [] new_messages = []
for i in range(len(messages)-1): for i in range(len(messages) - 1):
new_messages.append(messages[i]) new_messages.append(messages[i])
if messages[i]["role"] == messages[i+1]["role"]: if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user": if messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""}) new_messages.append(
{"role": "assistant", "content": ""}
)
else: else:
new_messages.append({"role": "user", "content": ""}) new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1]) new_messages.append(messages[-1])
@@ -252,118 +347,179 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e: except Exception as e:
raise e raise e
async def acompletion(self, async def acompletion(
self,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None, logging_obj=None,
headers=None headers=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_aclient = client openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=openai_aclient.api_key, api_key=openai_aclient.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
"api_base": openai_aclient._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await openai_aclient.chat.completions.create(**data) response = await openai_aclient.chat.completions.create(**data)
stringified_response = response.model_dump_json() stringified_response = response.model_dump_json()
logging_obj.post_call( logging_obj.post_call(
input=data['messages'], input=data["messages"],
api_key=api_key, api_key=api_key,
original_response=stringified_response, original_response=stringified_response,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
) )
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response) return convert_to_model_response_object(
response_object=json.loads(stringified_response),
model_response_object=model_response,
)
except Exception as e: except Exception as e:
raise e raise e
def streaming(self, def streaming(
self,
logging_obj, logging_obj,
timeout: float, timeout: float,
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
client = None, client=None,
max_retries=None, max_retries=None,
headers=None headers=None,
): ):
if client is None: if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_client = client openai_client = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=api_key, api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data}, additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": False,
"complete_input_dict": data,
},
) )
response = openai_client.chat.completions.create(**data) response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streamwrapper return streamwrapper
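For context on what streaming() wraps: a minimal sketch (not part of this commit) of consuming a streamed chat completion with the openai>=1.0 SDK, which is the iterator CustomStreamWrapper receives here. The model name and prompt are placeholders.

# Sketch only, assuming openai>=1.0 is installed and OPENAI_API_KEY is set in the environment.
from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content  # may be None on some chunks
    if delta:
        print(delta, end="")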
async def async_streaming(self, async def async_streaming(
self,
logging_obj, logging_obj,
timeout: float, timeout: float,
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None, max_retries=None,
headers=None headers=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_aclient = client openai_aclient = client
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data['messages'], input=data["messages"],
api_key=api_key, api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": headers,
"api_base": api_base,
"acompletion": True,
"complete_input_dict": data,
},
) )
response = await openai_aclient.chat.completions.create(**data) response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions. except (
Exception
) as e: # need to exception handle here. async exceptions don't get caught in sync functions.
if response is not None and hasattr(response, "text"): if response is not None and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}") raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
)
else: else:
if type(e).__name__ == "ReadTimeout": if type(e).__name__ == "ReadTimeout":
raise OpenAIError(status_code=408, message=f"{type(e).__name__}") raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
else: else:
raise OpenAIError(status_code=500, message=f"{str(e)}") raise OpenAIError(status_code=500, message=f"{str(e)}")
async def aembedding( async def aembedding(
self, self,
input: list, input: list,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.embeddings.create(**data) # type: ignore response = await openai_aclient.embeddings.create(**data) # type: ignore
@@ -385,7 +541,8 @@ class OpenAIChatCompletion(BaseLLM):
) )
raise e raise e
def embedding(self, def embedding(
self,
model: str, model: str,
input: list, input: list,
timeout: float, timeout: float,
@@ -401,11 +558,7 @@ class OpenAIChatCompletion(BaseLLM):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
model = model model = model
data = { data = {"model": model, "input": input, **optional_params}
"model": model,
"input": input,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
@@ -420,7 +573,13 @@ class OpenAIChatCompletion(BaseLLM):
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
if client is None: if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_client = client openai_client = client
@@ -443,6 +602,7 @@ class OpenAIChatCompletion(BaseLLM):
raise e raise e
else: else:
import traceback import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc()) raise OpenAIError(status_code=500, message=traceback.format_exc())
async def aimage_generation( async def aimage_generation(
@@ -451,16 +611,22 @@ class OpenAIChatCompletion(BaseLLM):
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
api_key: Optional[str]=None, api_key: Optional[str] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None logging_obj=None,
): ):
response = None response = None
try: try:
if client is None: if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries) openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_aclient = client openai_aclient = client
response = await openai_aclient.images.generate(**data) # type: ignore response = await openai_aclient.images.generate(**data) # type: ignore
@@ -482,7 +648,8 @@ class OpenAIChatCompletion(BaseLLM):
) )
raise e raise e
def image_generation(self, def image_generation(
self,
model: Optional[str], model: Optional[str],
prompt: str, prompt: str,
timeout: float, timeout: float,
@@ -497,11 +664,7 @@ class OpenAIChatCompletion(BaseLLM):
exception_mapping_worked = False exception_mapping_worked = False
try: try:
model = model model = model
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
@@ -511,7 +674,13 @@ class OpenAIChatCompletion(BaseLLM):
# return response # return response
if client is None: if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries) openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
)
else: else:
openai_client = client openai_client = client
@@ -519,7 +688,12 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=openai_client.api_key, api_key=openai_client.api_key,
additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data}, additional_args={
"headers": {"Authorization": f"Bearer {openai_client.api_key}"},
"api_base": openai_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
},
) )
## COMPLETION CALL ## COMPLETION CALL
@@ -541,8 +715,10 @@ class OpenAIChatCompletion(BaseLLM):
raise e raise e
else: else:
import traceback import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc()) raise OpenAIError(status_code=500, message=traceback.format_exc())
class OpenAITextCompletion(BaseLLM): class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client _client_session: httpx.Client
@@ -558,15 +734,21 @@ class OpenAITextCompletion(BaseLLM):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
return headers return headers
def convert_to_model_response_object(self, response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None): def convert_to_model_response_object(
self,
response_object: Optional[dict] = None,
model_response_object: Optional[ModelResponse] = None,
):
try: try:
## RESPONSE OBJECT ## RESPONSE OBJECT
if response_object is None or model_response_object is None: if response_object is None or model_response_object is None:
raise ValueError("Error in response object format") raise ValueError("Error in response object format")
choice_list=[] choice_list = []
for idx, choice in enumerate(response_object["choices"]): for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["text"], role="assistant") message = Message(content=choice["text"], role="assistant")
choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message) choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
)
choice_list.append(choice) choice_list.append(choice)
model_response_object.choices = choice_list model_response_object.choices = choice_list
@@ -579,24 +761,28 @@ class OpenAITextCompletion(BaseLLM):
if "model" in response_object: if "model" in response_object:
model_response_object.model = response_object["model"] model_response_object.model = response_object["model"]
model_response_object._hidden_params["original_response"] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response model_response_object._hidden_params[
"original_response"
] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
return model_response_object return model_response_object
except Exception as e: except Exception as e:
raise e raise e
def completion(self, def completion(
self,
model_response: ModelResponse, model_response: ModelResponse,
api_key: str, api_key: str,
model: str, model: str,
messages: list, messages: list,
print_verbose: Optional[Callable]=None, print_verbose: Optional[Callable] = None,
api_base: Optional[str]=None, api_base: Optional[str] = None,
logging_obj=None, logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers: Optional[dict]=None): headers: Optional[dict] = None,
):
super().completion() super().completion()
exception_mapping_worked = False exception_mapping_worked = False
try: try:
@@ -607,7 +793,11 @@ class OpenAITextCompletion(BaseLLM):
api_base = f"{api_base}/completions" api_base = f"{api_base}/completions"
if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list: if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
):
prompt = messages[0]["content"] prompt = messages[0]["content"]
else: else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore prompt = " ".join([message["content"] for message in messages]) # type: ignore
@@ -615,24 +805,38 @@ class OpenAITextCompletion(BaseLLM):
# don't send max retries to the api, if set # don't send max retries to the api, if set
optional_params.pop("max_retries", None) optional_params.pop("max_retries", None)
data = { data = {"model": model, "prompt": prompt, **optional_params}
"model": model,
"prompt": prompt,
**optional_params
}
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=messages, input=messages,
api_key=api_key, api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "complete_input_dict": data}, additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
) )
if acompletion == True: if acompletion == True:
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
headers=headers,
model_response=model_response,
model=model,
)
else: else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
elif optional_params.get("stream", False): elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
data=data,
headers=headers,
model_response=model_response,
model=model,
)
else: else:
response = httpx.post( response = httpx.post(
url=f"{api_base}", url=f"{api_base}",
@@ -640,7 +844,9 @@ class OpenAITextCompletion(BaseLLM):
headers=headers, headers=headers,
) )
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text) raise OpenAIError(
status_code=response.status_code, message=response.text
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@@ -654,11 +860,15 @@ class OpenAITextCompletion(BaseLLM):
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response) return self.convert_to_model_response_object(
response_object=response.json(),
model_response_object=model_response,
)
except Exception as e: except Exception as e:
raise e raise e
async def acompletion(self, async def acompletion(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
data: dict, data: dict,
@@ -666,13 +876,21 @@ class OpenAITextCompletion(BaseLLM):
model_response: ModelResponse, model_response: ModelResponse,
prompt: str, prompt: str,
api_key: str, api_key: str,
model: str): model: str,
):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout) response = await client.post(
api_base,
json=data,
headers=headers,
timeout=litellm.request_timeout,
)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text) raise OpenAIError(
status_code=response.status_code, message=response.text
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@@ -686,52 +904,71 @@ class OpenAITextCompletion(BaseLLM):
) )
## RESPONSE OBJECT ## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response) return self.convert_to_model_response_object(
response_object=response_json, model_response_object=model_response
)
except Exception as e: except Exception as e:
raise e raise e
def streaming(self, def streaming(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
data: dict, data: dict,
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str model: str,
): ):
with httpx.stream( with httpx.stream(
url=f"{api_base}", url=f"{api_base}",
json=data, json=data,
headers=headers, headers=headers,
method="POST", method="POST",
timeout=litellm.request_timeout timeout=litellm.request_timeout,
) as response: ) as response:
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text) raise OpenAIError(
status_code=response.status_code, message=response.text
)
streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
async def async_streaming(self, async def async_streaming(
self,
logging_obj, logging_obj,
api_base: str, api_base: str,
data: dict, data: dict,
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str): model: str,
):
client = httpx.AsyncClient() client = httpx.AsyncClient()
async with client.stream( async with client.stream(
url=f"{api_base}", url=f"{api_base}",
json=data, json=data,
headers=headers, headers=headers,
method="POST", method="POST",
timeout=litellm.request_timeout timeout=litellm.request_timeout,
) as response: ) as response:
try: try:
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text) raise OpenAIError(
status_code=response.status_code, message=response.text
)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
@@ -1,30 +1,41 @@
from typing import List, Dict from typing import List, Dict
import types import types
class OpenrouterConfig():
class OpenrouterConfig:
""" """
Reference: https://openrouter.ai/docs#format Reference: https://openrouter.ai/docs#format
""" """
# OpenRouter-only parameters # OpenRouter-only parameters
extra_body: Dict[str, List[str]] = { extra_body: Dict[str, List[str]] = {"transforms": []} # default transforms to []
'transforms': [] # default transforms to []
}
def __init__(
def __init__(self, self,
transforms: List[str] = [], transforms: List[str] = [],
models: List[str] = [], models: List[str] = [],
route: str = '', route: str = "",
) -> None: ) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
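The provider config classes in this commit (OpenrouterConfig above, PalmConfig and PetalsConfig below) share one pattern: __init__ writes any non-None arguments onto the class object itself, and get_config() reads class attributes back while filtering out dunders, callables, and None values. A self-contained sketch of that pattern, using an illustrative DemoConfig/retries name rather than anything from litellm:

# Stand-alone sketch of the config pattern; DemoConfig and retries are illustrative only.
import types
from typing import Optional


class DemoConfig:
    retries: Optional[int] = None

    def __init__(self, retries: Optional[int] = None) -> None:
        for key, value in locals().items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)  # mutates the class, not the instance

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
            and v is not None
        }


DemoConfig(retries=3)
print(DemoConfig.get_config())  # {'retries': 3}

Because setattr targets self.__class__, configured values persist on the class across instances, which is why get_config() is a classmethod.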
@@ -7,17 +7,22 @@ from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
import litellm import litellm
import sys, httpx import sys, httpx
class PalmError(Exception): class PalmError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat") self.request = httpx.Request(
method="POST",
url="https://developers.generativeai.google/api/python/google/generativeai/chat",
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class PalmConfig():
class PalmConfig:
""" """
Reference: https://developers.generativeai.google/api/python/google/generativeai/chat Reference: https://developers.generativeai.google/api/python/google/generativeai/chat
@@ -37,35 +42,47 @@ class PalmConfig():
- `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output - `max_output_tokens` (int): Sets the maximum number of tokens to be returned in the output
""" """
context: Optional[str]=None
examples: Optional[list]=None
temperature: Optional[float]=None
candidate_count: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
max_output_tokens: Optional[int]=None
def __init__(self,
context: Optional[str]=None,
examples: Optional[list]=None,
temperature: Optional[float]=None,
candidate_count: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
max_output_tokens: Optional[int]=None) -> None:
context: Optional[str] = None
examples: Optional[list] = None
temperature: Optional[float] = None
candidate_count: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
max_output_tokens: Optional[int] = None
def __init__(
self,
context: Optional[str] = None,
examples: Optional[list] = None,
temperature: Optional[float] = None,
candidate_count: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_output_tokens: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion( def completion(
@@ -83,30 +100,32 @@ def completion(
try: try:
import google.generativeai as palm import google.generativeai as palm
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
palm.configure(api_key=api_key) palm.configure(api_key=api_key)
model = model model = model
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream", None) # palm does not support streaming, so we handle this by fake streaming in main.py inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.PalmConfig.get_config() config = litellm.PalmConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
prompt = "" prompt = ""
for message in messages: for message in messages:
if "role" in message: if "role" in message:
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
else: else:
prompt += f"{message['content']}" prompt += f"{message['content']}"
@@ -142,22 +161,25 @@ def completion(
message_obj = Message(content=item["output"]) message_obj = Message(content=item["output"])
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices(index=idx+1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise PalmError(message=traceback.format_exc(), status_code=response.status_code) raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)
try: try:
completion_response = model_response["choices"][0]["message"].get("content") completion_response = model_response["choices"][0]["message"].get("content")
except: except:
raise PalmError(status_code=400, message=f"No response received. Original response - {response}") raise PalmError(
status_code=400,
message=f"No response received. Original response - {response}",
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@@ -167,11 +189,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
@@ -8,6 +8,7 @@ import litellm
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class PetalsError(Exception): class PetalsError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@@ -16,7 +17,8 @@ class PetalsError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class PetalsConfig():
class PetalsConfig:
""" """
Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below: The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
@@ -37,33 +39,52 @@ class PetalsConfig():
- `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper. - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
""" """
max_length: Optional[int]=None
max_new_tokens: Optional[int]=litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool]=None
temperature: Optional[float]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
repetition_penalty: Optional[float]=None
def __init__(self,
max_length: Optional[int]=None,
max_new_tokens: Optional[int]=litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool]=None,
temperature: Optional[float]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
repetition_penalty: Optional[float]=None) -> None:
max_length: Optional[int] = None
max_new_tokens: Optional[
int
] = litellm.max_tokens # petals requires max tokens to be set
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
def __init__(
self,
max_length: Optional[int] = None,
max_new_tokens: Optional[
int
] = litellm.max_tokens, # petals requires max tokens to be set
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def completion( def completion(
model: str, model: str,
@@ -81,7 +102,9 @@ def completion(
## Load Config ## Load Config
config = litellm.PetalsConfig.get_config() config = litellm.PetalsConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
if model in litellm.custom_prompt_dict: if model in litellm.custom_prompt_dict:
@@ -91,7 +114,7 @@ def completion(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
@@ -101,13 +124,12 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": optional_params, "api_base": api_base}, additional_args={
"complete_input_dict": optional_params,
"api_base": api_base,
},
) )
data = { data = {"model": model, "inputs": prompt, **optional_params}
"model": model,
"inputs": prompt,
**optional_params
}
## COMPLETION CALL ## COMPLETION CALL
response = requests.post(api_base, data=data) response = requests.post(api_base, data=data)
@@ -138,7 +160,9 @@ def completion(
model = model model = model
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, add_bos_token=False) tokenizer = AutoTokenizer.from_pretrained(
model, use_fast=False, add_bos_token=False
)
model_obj = AutoDistributedModelForCausalLM.from_pretrained(model) model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
## LOGGING ## LOGGING
@@ -167,9 +191,7 @@ def completion(
if len(output_text) > 0: if len(output_text) > 0:
model_response["choices"][0]["message"]["content"] = output_text model_response["choices"][0]["message"]["content"] = output_text
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content")) encoding.encode(model_response["choices"][0]["message"].get("content"))
) )
@@ -179,11 +201,12 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
@@ -4,9 +4,11 @@ import json
from jinja2 import Template, exceptions, Environment, meta from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any from typing import Optional, Any
def default_pt(messages): def default_pt(messages):
return " ".join(message["content"] for message in messages) return " ".join(message["content"] for message in messages)
# alpaca prompt template - for models like mythomax, etc. # alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages): def alpaca_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
@@ -19,48 +21,45 @@ def alpaca_pt(messages):
"pre_message": "### Instruction:\n", "pre_message": "### Instruction:\n",
"post_message": "\n\n", "post_message": "\n\n",
}, },
"assistant": { "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
"pre_message": "### Response:\n",
"post_message": "\n\n"
}
}, },
bos_token="<s>", bos_token="<s>",
eos_token="</s>", eos_token="</s>",
messages=messages messages=messages,
) )
return prompt return prompt
# Llama2 prompt template # Llama2 prompt template
def llama_2_chat_pt(messages): def llama_2_chat_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
"system": { "system": {
"pre_message": "[INST] <<SYS>>\n", "pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n" "post_message": "\n<</SYS>>\n [/INST]\n",
}, },
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348 "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"pre_message": "[INST] ", "pre_message": "[INST] ",
"post_message": " [/INST]\n" "post_message": " [/INST]\n",
}, },
"assistant": { "assistant": {
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
} },
}, },
messages=messages, messages=messages,
bos_token="<s>", bos_token="<s>",
eos_token="</s>" eos_token="</s>",
) )
return prompt return prompt
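Both templates above are thin wrappers over custom_prompt(), which concatenates each message's content between that role's pre_message and post_message strings. A hedged usage sketch, assuming litellm is installed and exposes custom_prompt at this module path:

# Sketch only: reusing the llama-2 role_dict shown above with illustrative messages.
from litellm.llms.prompt_templates.factory import custom_prompt

prompt = custom_prompt(
    role_dict={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"post_message": "\n"},
    },
    messages=[
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
print(prompt)  # prints the llama-2 style prompt string built from pre/post wrapping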
def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
def ollama_pt(
model, messages
): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model: if "instruct" in model:
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
"system": { "system": {"pre_message": "### System:\n", "post_message": "\n"},
"pre_message": "### System:\n",
"post_message": "\n"
},
"user": { "user": {
"pre_message": "### User:\n", "pre_message": "### User:\n",
"post_message": "\n", "post_message": "\n",
@@ -68,10 +67,10 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
"assistant": { "assistant": {
"pre_message": "### Response:\n", "pre_message": "### Response:\n",
"post_message": "\n", "post_message": "\n",
} },
}, },
final_prompt_value="### Response:", final_prompt_value="### Response:",
messages=messages messages=messages,
) )
elif "llava" in model: elif "llava" in model:
prompt = "" prompt = ""
@@ -88,36 +87,31 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
elif element["type"] == "image_url": elif element["type"] == "image_url":
image_url = element["image_url"]["url"] image_url = element["image_url"]["url"]
images.append(image_url) images.append(image_url)
return { return {"prompt": prompt, "images": images}
"prompt": prompt,
"images": images
}
else: else:
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages) prompt = "".join(
m["content"]
if isinstance(m["content"], str) is str
else "".join(m["content"])
for m in messages
)
return prompt return prompt
def mistral_instruct_pt(messages): def mistral_instruct_pt(messages):
prompt = custom_prompt( prompt = custom_prompt(
initial_prompt_value="<s>", initial_prompt_value="<s>",
role_dict={ role_dict={
"system": { "system": {"pre_message": "[INST]", "post_message": "[/INST]"},
"pre_message": "[INST]", "user": {"pre_message": "[INST]", "post_message": "[/INST]"},
"post_message": "[/INST]" "assistant": {"pre_message": "[INST]", "post_message": "[/INST]"},
},
"user": {
"pre_message": "[INST]",
"post_message": "[/INST]"
},
"assistant": {
"pre_message": "[INST]",
"post_message": "[/INST]"
}
}, },
final_prompt_value="</s>", final_prompt_value="</s>",
messages=messages messages=messages,
) )
return prompt return prompt
# Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 # Falcon prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def falcon_instruct_pt(messages): def falcon_instruct_pt(messages):
prompt = "" prompt = ""
@@ -125,11 +119,16 @@ def falcon_instruct_pt(messages):
if message["role"] == "system": if message["role"] == "system":
prompt += message["content"] prompt += message["content"]
else: else:
prompt += message['role']+":"+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n") prompt += (
message["role"]
+ ":"
+ message["content"].replace("\r\n", "\n").replace("\n\n", "\n")
)
prompt += "\n\n" prompt += "\n\n"
return prompt return prompt
def falcon_chat_pt(messages): def falcon_chat_pt(messages):
prompt = "" prompt = ""
for message in messages: for message in messages:
@@ -142,6 +141,7 @@ def falcon_chat_pt(messages):
return prompt return prompt
# MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110 # MPT prompt template - from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py#L110
def mpt_chat_pt(messages): def mpt_chat_pt(messages):
prompt = "" prompt = ""
@@ -154,6 +154,7 @@ def mpt_chat_pt(messages):
prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n" prompt += "<|im_start|>user" + message["content"] + "<|im_end|>" + "\n"
return prompt return prompt
# WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format # WizardCoder prompt template - https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0#prompt-format
def wizardcoder_pt(messages): def wizardcoder_pt(messages):
prompt = "" prompt = ""
@@ -166,6 +167,7 @@ def wizardcoder_pt(messages):
prompt += "### Response:\n" + message["content"] + "\n\n" prompt += "### Response:\n" + message["content"] + "\n\n"
return prompt return prompt
# Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model # Phind-CodeLlama prompt template - https://huggingface.co/Phind/Phind-CodeLlama-34B-v2#how-to-prompt-the-model
def phind_codellama_pt(messages): def phind_codellama_pt(messages):
prompt = "" prompt = ""
@@ -178,13 +180,17 @@ def phind_codellama_pt(messages):
prompt += "### Assistant\n" + message["content"] + "\n\n" prompt += "### Assistant\n" + message["content"] + "\n\n"
return prompt return prompt
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] = None):
## get the tokenizer config from huggingface ## get the tokenizer config from huggingface
bos_token = "" bos_token = ""
eos_token = "" eos_token = ""
if chat_template is None: if chat_template is None:
def _get_tokenizer_config(hf_model_name): def _get_tokenizer_config(hf_model_name):
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json" url = (
f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
)
# Make a GET request to fetch the JSON data # Make a GET request to fetch the JSON data
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
@@ -193,8 +199,12 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
return {"status": "success", "tokenizer": tokenizer_config} return {"status": "success", "tokenizer": tokenizer_config}
else: else:
return {"status": "failure"} return {"status": "failure"}
tokenizer_config = _get_tokenizer_config(model) tokenizer_config = _get_tokenizer_config(model)
if tokenizer_config["status"] == "failure" or "chat_template" not in tokenizer_config["tokenizer"]: if (
tokenizer_config["status"] == "failure"
or "chat_template" not in tokenizer_config["tokenizer"]
):
raise Exception("No chat template found") raise Exception("No chat template found")
## read the bos token, eos token and chat template from the json ## read the bos token, eos token and chat template from the json
tokenizer_config = tokenizer_config["tokenizer"] tokenizer_config = tokenizer_config["tokenizer"]
@@ -207,7 +217,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
# Create a template object from the template text # Create a template object from the template text
env = Environment() env = Environment()
env.globals['raise_exception'] = raise_exception env.globals["raise_exception"] = raise_exception
try: try:
template = env.from_string(chat_template) template = env.from_string(chat_template)
except Exception as e: except Exception as e:
@@ -216,7 +226,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
def _is_system_in_template(): def _is_system_in_template():
try: try:
# Try rendering the template with a system message # Try rendering the template with a system message
response = template.render(messages=[{"role": "system", "content": "test"}], eos_token= "<eos>", bos_token= "<bos>") response = template.render(
messages=[{"role": "system", "content": "test"}],
eos_token="<eos>",
bos_token="<bos>",
)
return True return True
# This will be raised if Jinja attempts to render the system message and it can't # This will be raised if Jinja attempts to render the system message and it can't
@@ -226,36 +240,54 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
try: try:
# Render the template with the provided values # Render the template with the provided values
if _is_system_in_template(): if _is_system_in_template():
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=messages) rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=messages
)
else: else:
# treat a system message as a user message, if system not in template # treat a system message as a user message, if system not in template
try: try:
reformatted_messages = [] reformatted_messages = []
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
reformatted_messages.append({"role": "user", "content": message["content"]}) reformatted_messages.append(
{"role": "user", "content": message["content"]}
)
else: else:
reformatted_messages.append(message) reformatted_messages.append(message)
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=reformatted_messages) rendered_text = template.render(
bos_token=bos_token,
eos_token=eos_token,
messages=reformatted_messages,
)
except Exception as e: except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e): if "Conversation roles must alternate user/assistant" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = [] new_messages = []
for i in range(len(reformatted_messages)-1): for i in range(len(reformatted_messages) - 1):
new_messages.append(reformatted_messages[i]) new_messages.append(reformatted_messages[i])
if reformatted_messages[i]["role"] == reformatted_messages[i+1]["role"]: if (
reformatted_messages[i]["role"]
== reformatted_messages[i + 1]["role"]
):
if reformatted_messages[i]["role"] == "user": if reformatted_messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""}) new_messages.append(
{"role": "assistant", "content": ""}
)
else: else:
new_messages.append({"role": "user", "content": ""}) new_messages.append({"role": "user", "content": ""})
new_messages.append(reformatted_messages[-1]) new_messages.append(reformatted_messages[-1])
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages) rendered_text = template.render(
bos_token=bos_token, eos_token=eos_token, messages=new_messages
)
return rendered_text return rendered_text
except Exception as e: except Exception as e:
raise Exception(f"Error rendering template - {str(e)}") raise Exception(f"Error rendering template - {str(e)}")
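The core of hf_chat_template() is Jinja2 rendering: fetch chat_template from the model's tokenizer_config.json on the Hugging Face Hub, then render it with bos_token, eos_token, and messages. A minimal offline sketch of just the rendering step, with a hard-coded ChatML-style template standing in for the fetched one:

# Sketch only, assuming jinja2 is installed; the template string is illustrative.
from jinja2 import Environment


def raise_exception(message):
    raise Exception(message)


chat_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
)

env = Environment()
env.globals["raise_exception"] = raise_exception  # some templates call this on bad input
template = env.from_string(chat_template)
print(
    template.render(
        bos_token="<s>",
        eos_token="</s>",
        messages=[{"role": "user", "content": "Hello"}],
    )
)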
# Anthropic template # Anthropic template
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts def claude_2_1_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
""" """
Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human: Claude v2.1 allows system prompts (no Human: needed), but requires it be followed by Human:
- you can't just pass a system message - you can't just pass a system message
@@ -264,6 +296,7 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
if a system message is passed in and followed by an assistant message, insert a blank human message between them. if a system message is passed in and followed by an assistant message, insert a blank human message between them.
""" """
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
@@ -271,81 +304,88 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
prompt = "" prompt = ""
for idx, message in enumerate(messages): for idx, message in enumerate(messages):
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
elif message["role"] == "system": elif message["role"] == "system":
prompt += ( prompt += f"{message['content']}"
f"{message['content']}"
)
elif message["role"] == "assistant": elif message["role"] == "assistant":
if idx > 0 and messages[idx - 1]["role"] == "system": if idx > 0 and messages[idx - 1]["role"] == "system":
prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}" # Insert a blank human message
prompt += ( prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
)
prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn prompt += f"{AnthropicConstants.AI_PROMPT.value}" # prompt must end with \"\n\nAssistant: " turn
return prompt return prompt
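Concretely, the rules in the docstring above yield a prompt where the system text appears bare, each user turn is prefixed with "\n\nHuman: ", and the string always ends with "\n\nAssistant: ". A small hedged check, assuming claude_2_1_pt is importable from this module and using an illustrative conversation:

from litellm.llms.prompt_templates.factory import claude_2_1_pt

messages = [
    {"role": "system", "content": "You answer in one sentence."},
    {"role": "user", "content": "What is LiteLLM?"},
]
prompt = claude_2_1_pt(messages)
assert prompt == (
    "You answer in one sentence."
    "\n\nHuman: What is LiteLLM?"
    "\n\nAssistant: "
)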
### TOGETHER AI ### TOGETHER AI
def get_model_info(token, model): def get_model_info(token, model):
try: try:
headers = { headers = {"Authorization": f"Bearer {token}"}
'Authorization': f'Bearer {token}' response = requests.get("https://api.together.xyz/models/info", headers=headers)
}
response = requests.get('https://api.together.xyz/models/info', headers=headers)
if response.status_code == 200: if response.status_code == 200:
model_info = response.json() model_info = response.json()
for m in model_info: for m in model_info:
if m["name"].lower().strip() == model.strip(): if m["name"].lower().strip() == model.strip():
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None) return m["config"].get("prompt_format", None), m["config"].get(
"chat_template", None
)
return None, None return None, None
else: else:
return None, None return None, None
except Exception as e: # safely fail a prompt template request except Exception as e: # safely fail a prompt template request
return None, None return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template): def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None: if prompt_format is None:
return default_pt(messages) return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split('{prompt}') human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None: if chat_template is not None:
prompt = hf_chat_template(model=None, messages=messages, chat_template=chat_template) prompt = hf_chat_template(
model=None, messages=messages, chat_template=chat_template
)
elif prompt_format is not None: elif prompt_format is not None:
prompt = custom_prompt(role_dict={}, messages=messages, initial_prompt_value=human_prompt, final_prompt_value=assistant_prompt) prompt = custom_prompt(
role_dict={},
messages=messages,
initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt,
)
else: else:
prompt = default_pt(messages) prompt = default_pt(messages)
return prompt return prompt
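get_model_info() returns a prompt_format string containing a literal "{prompt}" marker; format_prompt_togetherai() splits on that marker to obtain the text placed before and after the user prompt. A tiny sketch with an illustrative format string (not taken from the Together AI models/info response):

prompt_format = "<human>: {prompt}\n<bot>:"  # illustrative value
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
print(repr(human_prompt))      # '<human>: '
print(repr(assistant_prompt))  # '\n<bot>:'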
### ###
def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/reference/complete_post
def anthropic_pt(
messages: list,
): # format - https://docs.anthropic.com/claude/reference/complete_post
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: " HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: " AI_PROMPT = "\n\nAssistant: "
prompt = "" prompt = ""
for idx, message in enumerate(messages): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: ` for idx, message in enumerate(
messages
): # needs to start with `\n\nHuman: ` and end with `\n\nAssistant: `
if message["role"] == "user": if message["role"] == "user":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
f"{AnthropicConstants.HUMAN_PROMPT.value}{message['content']}"
)
elif message["role"] == "system": elif message["role"] == "system":
prompt += ( prompt += f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
f"{AnthropicConstants.HUMAN_PROMPT.value}<admin>{message['content']}</admin>"
)
else: else:
prompt += ( prompt += f"{AnthropicConstants.AI_PROMPT.value}{message['content']}"
f"{AnthropicConstants.AI_PROMPT.value}{message['content']}" if (
) idx == 0 and message["role"] == "assistant"
if idx == 0 and message["role"] == "assistant": # ensure the prompt always starts with `\n\nHuman: ` ): # ensure the prompt always starts with `\n\nHuman: `
prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt prompt = f"{AnthropicConstants.HUMAN_PROMPT.value}" + prompt
prompt += f"{AnthropicConstants.AI_PROMPT.value}" prompt += f"{AnthropicConstants.AI_PROMPT.value}"
return prompt return prompt
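A minimal sketch of the string anthropic_pt builds for a short conversation, following the constants above:

messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
]
# anthropic_pt(messages) returns:
# "\n\nHuman: <admin>Be brief.</admin>\n\nHuman: Hi\n\nAssistant: "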
def gemini_text_image_pt(messages: list): def gemini_text_image_pt(messages: list):
""" """
{ {
@ -367,7 +407,9 @@ def gemini_text_image_pt(messages: list):
try: try:
import google.generativeai as genai import google.generativeai as genai
except: except:
raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai") raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
prompt = "" prompt = ""
images = [] images = []
@ -387,26 +429,36 @@ def gemini_text_image_pt(messages: list):
content = [prompt] + images content = [prompt] + images
return content return content
# Function call template # Function call template
def function_call_prompt(messages: list, functions: list): def function_call_prompt(messages: list, functions: list):
function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:" function_prompt = (
"Produce JSON OUTPUT ONLY! The following functions are available to you:"
)
for function in functions: for function in functions:
function_prompt += f"""\n{function}\n""" function_prompt += f"""\n{function}\n"""
function_added_to_prompt = False function_added_to_prompt = False
for message in messages: for message in messages:
if "system" in message["role"]: if "system" in message["role"]:
message['content'] += f"""{function_prompt}""" message["content"] += f"""{function_prompt}"""
function_added_to_prompt = True function_added_to_prompt = True
if function_added_to_prompt == False: if function_added_to_prompt == False:
messages.append({'role': 'system', 'content': f"""{function_prompt}"""}) messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages return messages
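A minimal sketch of the injection behaviour, using a hypothetical function schema:

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
functions = [{"name": "get_weather", "parameters": {"location": "string"}}]  # hypothetical schema
# No system message exists here, so function_call_prompt() appends one whose
# content is "Produce JSON OUTPUT ONLY! ..." followed by the schema above.
messages = function_call_prompt(messages=messages, functions=functions)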
# Custom prompt template # Custom prompt template
def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="", final_prompt_value: str="", bos_token: str="", eos_token: str=""): def custom_prompt(
role_dict: dict,
messages: list,
initial_prompt_value: str = "",
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
prompt = bos_token + initial_prompt_value prompt = bos_token + initial_prompt_value
bos_open = True bos_open = True
## a bos token is at the start of a system / human message ## a bos token is at the start of a system / human message
@ -418,8 +470,16 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += bos_token prompt += bos_token
bos_open = True bos_open = True
pre_message_str = role_dict[role]["pre_message"] if role in role_dict and "pre_message" in role_dict[role] else "" pre_message_str = (
post_message_str = role_dict[role]["post_message"] if role in role_dict and "post_message" in role_dict[role] else "" role_dict[role]["pre_message"]
if role in role_dict and "pre_message" in role_dict[role]
else ""
)
post_message_str = (
role_dict[role]["post_message"]
if role in role_dict and "post_message" in role_dict[role]
else ""
)
prompt += pre_message_str + message["content"] + post_message_str prompt += pre_message_str + message["content"] + post_message_str
if role == "assistant": if role == "assistant":
@ -429,7 +489,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
prompt += final_prompt_value prompt += final_prompt_value
return prompt return prompt
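A minimal usage sketch of custom_prompt; the role_dict values here are hypothetical, not a template shipped with litellm:

role_dict = {
    "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
    "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
    "assistant": {"pre_message": " ", "post_message": " "},
}
prompt = custom_prompt(
    role_dict=role_dict,
    messages=[{"role": "user", "content": "Hello"}],
    initial_prompt_value="",
    final_prompt_value="",
)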
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None, api_key: Optional[str]=None):
def prompt_factory(
model: str,
messages: list,
custom_llm_provider: Optional[str] = None,
api_key: Optional[str] = None,
):
original_model_name = model original_model_name = model
model = model.lower() model = model.lower()
if custom_llm_provider == "ollama": if custom_llm_provider == "ollama":
@ -441,13 +507,17 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return anthropic_pt(messages=messages) return anthropic_pt(messages=messages)
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model) prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template) return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini": elif custom_llm_provider == "gemini":
return gemini_text_image_pt(messages=messages) return gemini_text_image_pt(messages=messages)
try: try:
if "meta-llama/llama-2" in model and "chat" in model: if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages) return llama_2_chat_pt(messages=messages)
elif "tiiuae/falcon" in model: # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template. elif (
"tiiuae/falcon" in model
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.
if model == "tiiuae/falcon-180B-chat": if model == "tiiuae/falcon-180B-chat":
return falcon_chat_pt(messages=messages) return falcon_chat_pt(messages=messages)
elif "instruct" in model: elif "instruct" in model:
@ -457,17 +527,26 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
return mpt_chat_pt(messages=messages) return mpt_chat_pt(messages=messages)
elif "codellama/codellama" in model or "togethercomputer/codellama" in model: elif "codellama/codellama" in model or "togethercomputer/codellama" in model:
if "instruct" in model: if "instruct" in model:
return llama_2_chat_pt(messages=messages) # https://huggingface.co/blog/codellama#conversational-instructions return llama_2_chat_pt(
messages=messages
) # https://huggingface.co/blog/codellama#conversational-instructions
elif "wizardlm/wizardcoder" in model: elif "wizardlm/wizardcoder" in model:
return wizardcoder_pt(messages=messages) return wizardcoder_pt(messages=messages)
elif "phind/phind-codellama" in model: elif "phind/phind-codellama" in model:
return phind_codellama_pt(messages=messages) return phind_codellama_pt(messages=messages)
elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model): elif "togethercomputer/llama-2" in model and (
"instruct" in model or "chat" in model
):
return llama_2_chat_pt(messages=messages) return llama_2_chat_pt(messages=messages)
elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]: elif model in [
"gryphe/mythomax-l2-13b",
"gryphe/mythomix-l2-13b",
"gryphe/mythologic-l2-13b",
]:
return alpaca_pt(messages=messages) return alpaca_pt(messages=messages)
else: else:
return hf_chat_template(original_model_name, messages) return hf_chat_template(original_model_name, messages)
except Exception as e: except Exception as e:
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
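A minimal sketch of the dispatch above: a llama-2 chat model name routes to llama_2_chat_pt, and anything unrecognised falls back to default_pt.

prompt = prompt_factory(
    model="meta-llama/Llama-2-7b-chat-hf",  # contains "meta-llama/llama-2" + "chat"
    messages=[{"role": "user", "content": "Hello"}],
)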


@ -8,17 +8,21 @@ import litellm
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class ReplicateError(Exception): class ReplicateError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments") self.request = httpx.Request(
method="POST", url="https://api.replicate.com/v1/deployments"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class ReplicateConfig():
class ReplicateConfig:
""" """
Reference: https://replicate.com/meta/llama-2-70b-chat/api Reference: https://replicate.com/meta/llama-2-70b-chat/api
- `prompt` (string): The prompt to send to the model. - `prompt` (string): The prompt to send to the model.
@ -43,42 +47,57 @@ class ReplicateConfig():
Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models. Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
""" """
system_prompt: Optional[str]=None
max_new_tokens: Optional[int]=None
min_new_tokens: Optional[int]=None
temperature: Optional[int]=None
top_p: Optional[int]=None
top_k: Optional[int]=None
stop_sequences: Optional[str]=None
seed: Optional[int]=None
debug: Optional[bool]=None
def __init__(self,
system_prompt: Optional[str]=None,
max_new_tokens: Optional[int]=None,
min_new_tokens: Optional[int]=None,
temperature: Optional[int]=None,
top_p: Optional[int]=None,
top_k: Optional[int]=None,
stop_sequences: Optional[str]=None,
seed: Optional[int]=None,
debug: Optional[bool]=None) -> None:
system_prompt: Optional[str] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
stop_sequences: Optional[str] = None
seed: Optional[int] = None
debug: Optional[bool] = None
def __init__(
self,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
stop_sequences: Optional[str] = None,
seed: Optional[int] = None,
debug: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
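A minimal sketch of how these config classes are consumed by the "## Load Config" blocks further down (assuming litellm is imported as in this module): values set on litellm.ReplicateConfig only apply when the caller has not passed the same key directly.

litellm.ReplicateConfig(max_new_tokens=256)  # class-level default
optional_params = {"temperature": 0.2}       # explicit call-site value
for k, v in litellm.ReplicateConfig.get_config().items():
    if k not in optional_params:             # call-site value wins over config
        optional_params[k] = v
# optional_params -> {"temperature": 0.2, "max_new_tokens": 256}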
# Function to start a prediction and get the prediction URL # Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose): def start_prediction(
version_id, input_data, api_token, api_base, logging_obj, print_verbose
):
base_url = api_base base_url = api_base
if "deployments" in version_id: if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment") print_verbose("\nLiteLLM: Request to custom replicate deployment")
@ -88,7 +107,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
initial_prediction_data = { initial_prediction_data = {
@ -100,22 +119,31 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj, p
logging_obj.pre_call( logging_obj.pre_call(
input=input_data["prompt"], input=input_data["prompt"],
api_key="", api_key="",
additional_args={"complete_input_dict": initial_prediction_data, "headers": headers, "api_base": base_url}, additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
) )
response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers) response = requests.post(
f"{base_url}/predictions", json=initial_prediction_data, headers=headers
)
if response.status_code == 201: if response.status_code == 201:
response_data = response.json() response_data = response.json()
return response_data.get("urls", {}).get("get") return response_data.get("urls", {}).get("get")
else: else:
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}") raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming) # Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose): def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = "" output_string = ""
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
status = "" status = ""
@ -127,18 +155,22 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
if "output" in response_data: if "output" in response_data:
output_string = "".join(response_data['output']) output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}") print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get('status', None) status = response_data.get("status", None)
logs = response_data.get("logs", "") logs = response_data.get("logs", "")
if status == "failed": if status == "failed":
replicate_error = response_data.get("error", "") replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}, \nReplicate logs:{logs}") raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else: else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed" # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose("Replicate: Failed to fetch prediction status and output.") print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs return output_string, logs
# Function to handle prediction response (streaming) # Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose): def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = "" previous_output = ""
@ -146,7 +178,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
headers = { headers = {
"Authorization": f"Token {api_token}", "Authorization": f"Token {api_token}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
status = "" status = ""
while True and (status not in ["succeeded", "failed", "canceled"]): while True and (status not in ["succeeded", "failed", "canceled"]):
@ -155,20 +187,24 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
response = requests.get(prediction_url, headers=headers) response = requests.get(prediction_url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
status = response_data['status'] status = response_data["status"]
if "output" in response_data: if "output" in response_data:
output_string = "".join(response_data['output']) output_string = "".join(response_data["output"])
new_output = output_string[len(previous_output):] new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}") print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status} yield {"output": new_output, "status": status}
previous_output = output_string previous_output = output_string
status = response_data['status'] status = response_data["status"]
if status == "failed": if status == "failed":
replicate_error = response_data.get("error", "") replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}") raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else: else:
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed" # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}") print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
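A minimal sketch of consuming this generator, assuming prediction_url and api_token are already available:

for chunk in handle_prediction_response_streaming(prediction_url, api_token, print):
    # each chunk is {"output": <new text since last poll>, "status": <replicate status>}
    print(chunk["output"], end="")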
# Function to extract version ID from model string # Function to extract version ID from model string
@ -178,6 +214,7 @@ def model_to_version_id(model):
return split_model[1] return split_model[1]
return model return model
# Main function for prediction completion # Main function for prediction completion
def completion( def completion(
model: str, model: str,
@ -198,7 +235,9 @@ def completion(
## Load Config ## Load Config
config = litellm.ReplicateConfig.get_config() config = litellm.ReplicateConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
system_prompt = None system_prompt = None
@ -233,38 +272,53 @@ def completion(
input_data = { input_data = {
"prompt": prompt, "prompt": prompt,
"system_prompt": system_prompt, "system_prompt": system_prompt,
**optional_params **optional_params,
} }
# Otherwise, use the prompt as is # Otherwise, use the prompt as is
else: else:
input_data = { input_data = {"prompt": prompt, **optional_params}
"prompt": prompt,
**optional_params
}
## COMPLETION CALL ## COMPLETION CALL
## Replicate Completion calls have 2 steps ## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url ## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response ## Step2: Poll prediction url for response
## Step2: is handled with and without streaming ## Step2: is handled with and without streaming
model_response["created"] = int(time.time()) # for pricing this must remain right before calling api model_response["created"] = int(
prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose) time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
)
print_verbose(prediction_url) print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming) # Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request") print_verbose("streaming request")
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose) return handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
else: else:
result, logs = handle_prediction_response(prediction_url, api_key, print_verbose) result, logs = handle_prediction_response(
model_response["ended"] = time.time() # for pricing this must remain right after calling api prediction_url, api_key, print_verbose
)
model_response[
"ended"
] = time.time() # for pricing this must remain right after calling api
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
api_key="", api_key="",
original_response=result, original_response=result,
additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, }, additional_args={
"complete_input_dict": input_data,
"logs": logs,
"api_base": prediction_url,
},
) )
print_verbose(f"raw model_response: {result}") print_verbose(f"raw model_response: {result}")
@ -278,12 +332,14 @@ def completion(
# Calculate usage # Calculate usage
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(model_response["choices"][0]["message"].get("content", ""))) completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["model"] = "replicate/" + model model_response["model"] = "replicate/" + model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
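A hedged end-to-end sketch through the public API, assuming the usual litellm "replicate/" provider prefix; the model string is a placeholder for a real Replicate model or deployment.

import litellm

response = litellm.completion(
    model="replicate/<owner>/<model>:<version>",  # placeholder model string
    messages=[{"role": "user", "content": "Hello"}],
)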


@ -11,41 +11,60 @@ from copy import deepcopy
import httpx import httpx
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class SagemakerError(Exception): class SagemakerError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker") self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class SagemakerConfig():
class SagemakerConfig:
""" """
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
""" """
max_new_tokens: Optional[int]=None
top_p: Optional[float]=None
temperature: Optional[float]=None
return_full_text: Optional[bool]=None
def __init__(self,
max_new_tokens: Optional[int]=None,
top_p: Optional[float]=None,
temperature: Optional[float]=None,
return_full_text: Optional[bool]=None) -> None:
max_new_tokens: Optional[int] = None
top_p: Optional[float] = None
temperature: Optional[float] = None
return_full_text: Optional[bool] = None
def __init__(
self,
max_new_tokens: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
return_full_text: Optional[bool] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
""" """
SAGEMAKER AUTH Keys/Vars SAGEMAKER AUTH Keys/Vars
@ -55,6 +74,7 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""
# set os.environ['AWS_REGION_NAME'] = <your-region_name> # set os.environ['AWS_REGION_NAME'] = <your-region_name>
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -91,8 +111,8 @@ def completion(
# we need to read region name from env # we need to read region name from env
# I assume majority of users use .env for auth # I assume majority of users use .env for auth
region_name = ( region_name = (
get_secret("AWS_REGION_NAME") or get_secret("AWS_REGION_NAME")
"us-west-2" # default to us-west-2 if user not specified or "us-west-2" # default to us-west-2 if user not specified
) )
client = boto3.client( client = boto3.client(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
@ -106,7 +126,9 @@ def completion(
## Load Config ## Load Config
config = litellm.SagemakerConfig.get_config() config = litellm.SagemakerConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
model = model model = model
@ -117,7 +139,7 @@ def completion(
role_dict=model_prompt_details.get("roles", None), role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages messages=messages,
) )
else: else:
if hf_model_name is None: if hf_model_name is None:
@ -126,13 +148,14 @@ def completion(
hf_model_name = "meta-llama/Llama-2-7b-chat-hf" hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
else: # apply regular llama2 template else: # apply regular llama2 template
hf_model_name = "meta-llama/Llama-2-7b" hf_model_name = "meta-llama/Llama-2-7b"
hf_model_name = hf_model_name or model # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages) prompt = prompt_factory(model=hf_model_name, messages=messages)
data = json.dumps({ data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"inputs": prompt, "utf-8"
"parameters": inference_params )
}).encode('utf-8')
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
@ -146,7 +169,11 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str, "hf_model_name": hf_model_name}, additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
@ -184,12 +211,13 @@ def completion(
model_response["choices"][0]["message"]["content"] = completion_output model_response["choices"][0]["message"]["content"] = completion_output
except: except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500) raise SagemakerError(
message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
status_code=500,
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(encoding.encode(prompt))
encoding.encode(prompt)
)
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
@ -199,12 +227,14 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
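A hedged usage sketch, assuming the usual litellm "sagemaker/" prefix; the endpoint name is a placeholder, and hf_model_name (used above to pick a prompt template) is optional.

import litellm

response = litellm.completion(
    model="sagemaker/<your-endpoint-name>",          # placeholder endpoint name
    messages=[{"role": "user", "content": "Hello"}],
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",   # optional: prompt-template hint, per the code above
)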
def embedding(model: str,
def embedding(
model: str,
input: list, input: list,
model_response: EmbeddingResponse, model_response: EmbeddingResponse,
print_verbose: Callable, print_verbose: Callable,
@ -213,12 +243,14 @@ def embedding(model: str,
custom_prompt_dict={}, custom_prompt_dict={},
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None): logger_fn=None,
):
""" """
Supports Huggingface Jumpstart embeddings like GPT-6B Supports Huggingface Jumpstart embeddings like GPT-6B
""" """
### BOTO3 INIT ### BOTO3 INIT
import boto3 import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None)
@ -240,8 +272,8 @@ def embedding(model: str,
# we need to read region name from env # we need to read region name from env
# I assume majority of users use .env for auth # I assume majority of users use .env for auth
region_name = ( region_name = (
get_secret("AWS_REGION_NAME") or get_secret("AWS_REGION_NAME")
"us-west-2" # default to us-west-2 if user not specified or "us-west-2" # default to us-west-2 if user not specified
) )
client = boto3.client( client = boto3.client(
service_name="sagemaker-runtime", service_name="sagemaker-runtime",
@ -255,13 +287,13 @@ def embedding(model: str,
## Load Config ## Load Config
config = litellm.SagemakerConfig.get_config() config = litellm.SagemakerConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in inference_params
): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
#### HF EMBEDDING LOGIC #### HF EMBEDDING LOGIC
data = json.dumps({ data = json.dumps({"text_inputs": input}).encode("utf-8")
"text_inputs": input
}).encode('utf-8')
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
@ -295,7 +327,6 @@ def embedding(model: str,
original_response=response, original_response=response,
) )
response = json.loads(response["Body"].read().decode("utf8")) response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -307,20 +338,17 @@ def embedding(model: str,
print_verbose(f"raw model_response: {response}") print_verbose(f"raw model_response: {response}")
if "embedding" not in response: if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response") raise SagemakerError(status_code=500, message="embedding not found in response")
embeddings = response['embedding'] embeddings = response["embedding"]
if not isinstance(embeddings, list): if not isinstance(embeddings, list):
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}") raise SagemakerError(
status_code=422, message=f"Response not in expected format - {embeddings}"
)
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):
output_data.append( output_data.append(
{ {"object": "embedding", "index": idx, "embedding": embedding}
"object": "embedding",
"index": idx,
"embedding": embedding
}
) )
model_response["object"] = "list" model_response["object"] = "list"
@ -329,8 +357,10 @@ def embedding(model: str,
input_tokens = 0 input_tokens = 0
for text in input: for text in input:
input_tokens+=len(encoding.encode(text)) input_tokens += len(encoding.encode(text))
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens) model_response["usage"] = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
return model_response return model_response
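A minimal sketch of the usage accounting shown above, reusing the encoding and Usage objects already in scope here: token counts are summed over the input strings, and completion_tokens stays 0 for embeddings.

input = ["hello world", "how are you"]
input_tokens = sum(len(encoding.encode(text)) for text in input)
usage = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)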


@ -9,17 +9,21 @@ import httpx
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
class TogetherAIError(Exception): class TogetherAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url="https://api.together.xyz/inference") self.request = httpx.Request(
method="POST", url="https://api.together.xyz/inference"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class TogetherAIConfig():
class TogetherAIConfig:
""" """
Reference: https://docs.together.ai/reference/inference Reference: https://docs.together.ai/reference/inference
@ -39,33 +43,47 @@ class TogetherAIConfig():
- `logprobs` (int32, optional): This parameter is not described in the prompt. - `logprobs` (int32, optional): This parameter is not described in the prompt.
""" """
max_tokens: Optional[int]=None
stop: Optional[str]=None
temperature:Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
repetition_penalty: Optional[float]=None
logprobs: Optional[int]=None
def __init__(self,
max_tokens: Optional[int]=None,
stop: Optional[str]=None,
temperature:Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None,
repetition_penalty: Optional[float]=None,
logprobs: Optional[int]=None) -> None:
max_tokens: Optional[int] = None
stop: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
repetition_penalty: Optional[float] = None
logprobs: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
stop: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def validate_environment(api_key): def validate_environment(api_key):
@ -80,6 +98,7 @@ def validate_environment(api_key):
} }
return headers return headers
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -99,7 +118,9 @@ def completion(
## Load Config ## Load Config
config = litellm.TogetherAIConfig.get_config() config = litellm.TogetherAIConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}") print_verbose(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
@ -115,7 +136,12 @@ def completion(
messages=messages, messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages, api_key=api_key, custom_llm_provider="together_ai") # api key required to query together ai model list prompt = prompt_factory(
model=model,
messages=messages,
api_key=api_key,
custom_llm_provider="together_ai",
) # api key required to query together ai model list
data = { data = {
"model": model, "model": model,
@ -128,13 +154,14 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data, "headers": headers, "api_base": api_base}, additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
) )
## COMPLETION CALL ## COMPLETION CALL
if ( if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
response = requests.post( response = requests.post(
api_base, api_base,
headers=headers, headers=headers,
@ -143,11 +170,7 @@ def completion(
) )
return response.iter_lines() return response.iter_lines()
else: else:
response = requests.post( response = requests.post(api_base, headers=headers, data=json.dumps(data))
api_base,
headers=headers,
data=json.dumps(data)
)
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -170,30 +193,38 @@ def completion(
) )
elif "error" in completion_response["output"]: elif "error" in completion_response["output"]:
raise TogetherAIError( raise TogetherAIError(
message=json.dumps(completion_response["output"]), status_code=response.status_code message=json.dumps(completion_response["output"]),
status_code=response.status_code,
) )
if len(completion_response["output"]["choices"][0]["text"]) >= 0: if len(completion_response["output"]["choices"][0]["text"]) >= 0:
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"] model_response["choices"][0]["message"]["content"] = completion_response[
"output"
]["choices"][0]["text"]
## CALCULATING USAGE ## CALCULATING USAGE
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}") print_verbose(
f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}"
)
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) )
if "finish_reason" in completion_response["output"]["choices"][0]: if "finish_reason" in completion_response["output"]["choices"][0]:
model_response.choices[0].finish_reason = completion_response["output"]["choices"][0]["finish_reason"] model_response.choices[0].finish_reason = completion_response["output"][
"choices"
][0]["finish_reason"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass


@ -8,17 +8,21 @@ from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm import litellm
import httpx import httpx
class VertexAIError(Exception): class VertexAIError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
self.message = message self.message = message
self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/") self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request) self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__( super().__init__(
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class VertexAIConfig():
class VertexAIConfig:
""" """
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@ -34,28 +38,42 @@ class VertexAIConfig():
Note: Please make sure to modify the default parameters as required for your use case. Note: Please make sure to modify the default parameters as required for your use case.
""" """
temperature: Optional[float]=None
max_output_tokens: Optional[int]=None
top_p: Optional[float]=None
top_k: Optional[int]=None
def __init__(self,
temperature: Optional[float]=None,
max_output_tokens: Optional[int]=None,
top_p: Optional[float]=None,
top_k: Optional[int]=None) -> None:
temperature: Optional[float] = None
max_output_tokens: Optional[int] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
def __init__(
self,
temperature: Optional[float] = None,
max_output_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
if key != 'self' and value is not None: if key != "self" and value is not None:
setattr(self.__class__, key, value) setattr(self.__class__, key, value)
@classmethod @classmethod
def get_config(cls): def get_config(cls):
return {k: v for k, v in cls.__dict__.items() return {
if not k.startswith('__') k: v
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) for k, v in cls.__dict__.items()
and v is not None} if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def _get_image_bytes_from_url(image_url: str) -> bytes: def _get_image_bytes_from_url(image_url: str) -> bytes:
try: try:
@ -65,7 +83,7 @@ def _get_image_bytes_from_url(image_url: str) -> bytes:
return image_bytes return image_bytes
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
# Handle any request exceptions (e.g., connection error, timeout) # Handle any request exceptions (e.g., connection error, timeout)
return b'' # Return an empty bytes object or handle the error as needed return b"" # Return an empty bytes object or handle the error as needed
def _load_image_from_url(image_url: str): def _load_image_from_url(image_url: str):
@ -78,13 +96,18 @@ def _load_image_from_url(image_url: str):
Returns: Returns:
Image: The loaded image. Image: The loaded image.
""" """
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
image_bytes = _get_image_bytes_from_url(image_url) image_bytes = _get_image_bytes_from_url(image_url)
return Image.from_bytes(image_bytes) return Image.from_bytes(image_bytes)
def _gemini_vision_convert_messages(
messages: list def _gemini_vision_convert_messages(messages: list):
):
""" """
Converts given messages for GPT-4 Vision to Gemini format. Converts given messages for GPT-4 Vision to Gemini format.
@ -115,11 +138,23 @@ def _gemini_vision_convert_messages(
try: try:
import vertexai import vertexai
except: except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try: try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
Image,
)
# given messages for gpt-4 vision, convert them for gemini # given messages for gpt-4 vision, convert them for gemini
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
@ -159,6 +194,7 @@ def _gemini_vision_convert_messages(
except Exception as e: except Exception as e:
raise e raise e
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -171,22 +207,30 @@ def completion(
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool=False acompletion: bool = False,
): ):
try: try:
import vertexai import vertexai
except: except:
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`") raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`",
)
try: try:
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair from vertexai.preview.language_models import (
ChatModel,
CodeChatModel,
InputOutputTextPair,
)
from vertexai.language_models import TextGenerationModel, CodeGenerationModel from vertexai.language_models import TextGenerationModel, CodeGenerationModel
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig from vertexai.preview.generative_models import (
GenerativeModel,
Part,
GenerationConfig,
)
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
vertexai.init(project=vertex_project, location=vertex_location)
vertexai.init(
project=vertex_project, location=vertex_location
)
## Load Config ## Load Config
config = litellm.VertexAIConfig.get_config() config = litellm.VertexAIConfig.get_config()
@ -202,11 +246,19 @@ def completion(
raise ValueError("safety_settings must be a list") raise ValueError("safety_settings must be a list")
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
raise ValueError("safety_settings must be a list of dicts") raise ValueError("safety_settings must be a list of dicts")
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings] safety_settings = [
gapic_content_types.SafetySetting(x) for x in safety_settings
]
# vertexai does not use an API key, it looks for credentials.json in the environment # vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)]) prompt = " ".join(
[
message["content"]
for message in messages
if isinstance(message["content"], str)
]
)
mode = "" mode = ""
@ -240,23 +292,68 @@ def completion(
if acompletion == True: # [TODO] expand support to vertex ai chat + text models if acompletion == True: # [TODO] expand support to vertex ai chat + text models
if optional_params.get("stream", False) is True: if optional_params.get("stream", False) is True:
# async streaming # async streaming
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
return async_streaming(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
return async_completion(
llm_model=llm_model,
mode=mode,
prompt=prompt,
logging_obj=logging_obj,
request_str=request_str,
model=model,
model_response=model_response,
encoding=encoding,
messages=messages,
print_verbose=print_verbose,
**optional_params,
)
if mode == "": if mode == "":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
model_response = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
stream=stream,
)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n" request_str += f"llm_model.generate_content({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response_obj = llm_model.generate_content(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = llm_model.generate_content(
prompt,
generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings,
)
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "vision": elif mode == "vision":
@ -268,20 +365,34 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.generate_content( model_response = llm_model.generate_content(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params), generation_config=GenerationConfig(**optional_params),
safety_settings=safety_settings, safety_settings=safety_settings,
stream=True stream=True,
) )
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"response = llm_model.generate_content({content})\n" request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call ## LLM Call
response = llm_model.generate_content( response = llm_model.generate_content(
@ -293,37 +404,73 @@ def completion(
response_obj = response._raw_response response_obj = response._raw_response
elif mode == "chat": elif mode == "chat":
chat = llm_model.start_chat() chat = llm_model.start_chat()
request_str+= f"chat = llm_model.start_chat()\n" request_str += f"chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error, # NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request # we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params optional_params.pop(
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n" "stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = chat.send_message_streaming(prompt, **optional_params) model_response = chat.send_message_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = chat.send_message(prompt, **optional_params).text completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text": elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai optional_params.pop(
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n" "stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
model_response = llm_model.predict_streaming(prompt, **optional_params) model_response = llm_model.predict_streaming(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
return model_response return model_response
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
completion_response = llm_model.predict(prompt, **optional_params).text completion_response = llm_model.predict(prompt, **optional_params).text
## LOGGING ## LOGGING
@ -333,36 +480,53 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
if len(str(completion_response)) > 0: if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = str(
"content" completion_response
] = str(completion_response) )
model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name model_response["choices"][0].finish_reason = response_obj.candidates[
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, 0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count, completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count) total_tokens=response_obj.usage_metadata.total_token_count,
else:
prompt_tokens = len(
encoding.encode(prompt)
) )
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) )
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
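The chat and text branches above pop "stream" from optional_params before calling the Vertex SDK and re-add it afterwards, because Vertex AI rejects a stream kwarg. A minimal illustrative helper showing that same pattern, not part of the module itself (the helper name is invented; chat.send_message_streaming is the SDK call used above):

def send_chat_stream(chat, prompt: str, optional_params: dict):
    # Vertex AI raises if "stream" is passed through, so drop it before the SDK call
    optional_params.pop("stream", None)
    iterator = chat.send_message_streaming(prompt, **optional_params)
    # put the flag back so downstream code still treats the result as a stream
    optional_params["stream"] = True
    return iterator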
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
async def async_completion(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
encoding=None,
messages=None,
print_verbose=None,
**optional_params,
):
""" """
Add support for acompletion calls for gemini-pro Add support for acompletion calls for gemini-pro
""" """
@ -373,8 +537,17 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# gemini-pro # gemini-pro
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params)) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params)
)
completion_response = response_obj.text completion_response = response_obj.text
response_obj = response_obj._raw_response response_obj = response_obj._raw_response
elif mode == "vision": elif mode == "vision":
@ -386,12 +559,18 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
request_str += f"response = llm_model.generate_content({content})\n" request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
## LLM Call ## LLM Call
response = await llm_model._generate_content_async( response = await llm_model._generate_content_async(
contents=content, contents=content, generation_config=GenerationConfig(**optional_params)
generation_config=GenerationConfig(**optional_params)
) )
completion_response = response.text completion_response = response.text
response_obj = response._raw_response response_obj = response._raw_response
@ -399,14 +578,28 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
# chat-bison etc. # chat-bison etc.
chat = llm_model.start_chat() chat = llm_model.start_chat()
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await chat.send_message_async(prompt, **optional_params) response_obj = await chat.send_message_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
elif mode == "text": elif mode == "text":
# gecko etc. # gecko etc.
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response_obj = await llm_model.predict_async(prompt, **optional_params) response_obj = await llm_model.predict_async(prompt, **optional_params)
completion_response = response_obj.text completion_response = response_obj.text
@ -417,48 +610,74 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
## RESPONSE OBJECT ## RESPONSE OBJECT
if len(str(completion_response)) > 0: if len(str(completion_response)) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"]["content"] = str(
"content" completion_response
] = str(completion_response) )
model_response["choices"][0]["message"]["content"] = str(completion_response) model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = model model_response["model"] = model
## CALCULATING USAGE ## CALCULATING USAGE
if model in litellm.vertex_language_models and response_obj is not None: if model in litellm.vertex_language_models and response_obj is not None:
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name model_response["choices"][0].finish_reason = response_obj.candidates[
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count, 0
].finish_reason.name
usage = Usage(
prompt_tokens=response_obj.usage_metadata.prompt_token_count,
completion_tokens=response_obj.usage_metadata.candidates_token_count, completion_tokens=response_obj.usage_metadata.candidates_token_count,
total_tokens=response_obj.usage_metadata.total_token_count) total_tokens=response_obj.usage_metadata.total_token_count,
else:
prompt_tokens = len(
encoding.encode(prompt)
) )
else:
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len( completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", "")) encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) )
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
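When the model is not one of litellm.vertex_language_models, the code above reconstructs usage by encoding the prompt and completion text with the injected `encoding` object. A rough sketch of that fallback, using tiktoken's cl100k_base purely as a stand-in encoder (an assumption; the real encoder is passed in by the caller):

import tiktoken  # stand-in; the actual encoder is supplied via the `encoding` argument

def fallback_usage(prompt: str, completion_text: str) -> dict:
    enc = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(enc.encode(prompt))
    completion_tokens = len(enc.encode(completion_text))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }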
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
async def async_streaming(
llm_model,
mode: str,
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
messages=None,
print_verbose=None,
**optional_params,
):
""" """
Add support for async streaming calls for gemini-pro Add support for async streaming calls for gemini-pro
""" """
from vertexai.preview.generative_models import GenerationConfig from vertexai.preview.generative_models import GenerationConfig
if mode == "": if mode == "":
# gemini-pro # gemini-pro
chat = llm_model.start_chat() chat = llm_model.start_chat()
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream) input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await chat.send_message_async(
prompt, generation_config=GenerationConfig(**optional_params), stream=stream
)
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "vision": elif mode == "vision":
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
@ -470,33 +689,68 @@ async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_r
content = [prompt] + images content = [prompt] + images
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n" request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model._generate_content_streaming_async( response = llm_model._generate_content_streaming_async(
contents=content, contents=content,
generation_config=GenerationConfig(**optional_params), generation_config=GenerationConfig(**optional_params),
stream=True stream=True,
) )
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "chat": elif mode == "chat":
chat = llm_model.start_chat() chat = llm_model.start_chat()
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params optional_params.pop(
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n" "stream", None
) # vertex ai raises an error when passing stream in optional params
request_str += (
f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = chat.send_message_streaming_async(prompt, **optional_params) response = chat.send_message_streaming_async(prompt, **optional_params)
optional_params["stream"] = True optional_params["stream"] = True
elif mode == "text": elif mode == "text":
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai optional_params.pop(
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n" "stream", None
) # See note above on handling streaming for vertex ai
request_str += (
f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
)
## LOGGING ## LOGGING
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str}) logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = llm_model.predict_streaming_async(prompt, **optional_params) response = llm_model.predict_streaming_async(prompt, **optional_params)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
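async_streaming above yields OpenAI-format chunks through CustomStreamWrapper. A hedged sketch of how a caller might drain such an async generator; the commented usage line is illustrative only, since the real arguments are wired up inside litellm:

import asyncio

async def drain_stream(stream):
    # collect every chunk the wrapper yields
    collected = []
    async for chunk in stream:
        collected.append(chunk)  # each chunk mirrors an OpenAI streaming delta
    return collected

# usage (illustrative): chunks = asyncio.run(drain_stream(async_streaming(llm_model, "", prompt, model, model_response, ...)))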


@ -6,7 +6,10 @@ import time, httpx
from typing import Callable, Any from typing import Callable, Any
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
llm = None llm = None
class VLLMError(Exception): class VLLMError(Exception):
def __init__(self, status_code, message): def __init__(self, status_code, message):
self.status_code = status_code self.status_code = status_code
@ -17,17 +20,20 @@ class VLLMError(Exception):
self.message self.message
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# check if vllm is installed # check if vllm is installed
def validate_environment(model: str): def validate_environment(model: str):
global llm global llm
try: try:
from vllm import LLM, SamplingParams # type: ignore from vllm import LLM, SamplingParams # type: ignore
if llm is None: if llm is None:
llm = LLM(model=model) llm = LLM(model=model)
return llm, SamplingParams return llm, SamplingParams
except Exception as e: except Exception as e:
raise VLLMError(status_code=0, message=str(e)) raise VLLMError(status_code=0, message=str(e))
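validate_environment caches a single global vllm LLM instance so model weights are only loaded once across calls. The same lazy-singleton idea in a stripped-down, library-agnostic sketch (names are illustrative, not litellm APIs):

_engine = None

def get_engine(factory, *args, **kwargs):
    # build the expensive object on first use, then hand back the cached instance
    global _engine
    if _engine is None:
        _engine = factory(*args, **kwargs)
    return _engine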
def completion( def completion(
model: str, model: str,
messages: list, messages: list,
@ -53,12 +59,11 @@ def completion(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages messages=messages,
) )
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -69,8 +74,9 @@ def completion(
if llm: if llm:
outputs = llm.generate(prompt, sampling_params) outputs = llm.generate(prompt, sampling_params)
else: else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
@ -96,16 +102,14 @@ def completion(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
return model_response return model_response
def batch_completions( def batch_completions(
model: str, model: str, messages: list, optional_params=None, custom_prompt_dict={}
messages: list,
optional_params=None,
custom_prompt_dict={}
): ):
""" """
Example usage: Example usage:
@ -150,7 +154,7 @@ def batch_completions(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=message messages=message,
) )
prompts.append(prompt) prompts.append(prompt)
else: else:
@ -161,7 +165,9 @@ def batch_completions(
if llm: if llm:
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
else: else:
raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm") raise VLLMError(
status_code=0, message="Need to pass in a model name to initialize vllm"
)
final_outputs = [] final_outputs = []
for output in outputs: for output in outputs:
@ -178,12 +184,13 @@ def batch_completions(
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens total_tokens=prompt_tokens + completion_tokens,
) )
model_response.usage = usage model_response.usage = usage
final_outputs.append(model_response) final_outputs.append(model_response)
return final_outputs return final_outputs
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass
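The truncated docstring above advertises example usage for batch_completions. A hedged sketch of a call based only on the signature shown; the model name is a placeholder, and it assumes the optional_params keys map onto vllm SamplingParams fields:

responses = batch_completions(
    model="facebook/opt-125m",  # placeholder model name
    messages=[
        [{"role": "user", "content": "Hello, how are you?"}],
        [{"role": "user", "content": "Write a one-line poem."}],
    ],
    optional_params={"temperature": 0.5, "max_tokens": 64},
)
for r in responses:
    print(r["choices"][0]["message"]["content"])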

File diff suppressed because it is too large.


@ -3,10 +3,12 @@ from typing import Optional, List, Union, Dict, Literal
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
""" """
def json(self, **kwargs): def json(self, **kwargs):
try: try:
return self.model_dump() # noqa return self.model_dump() # noqa
@ -49,7 +51,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
request_timeout: Optional[int] = None request_timeout: Optional[int] = None
class Config: class Config:
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase): class ModelInfoDelete(LiteLLMBase):
id: Optional[str] id: Optional[str]
@ -57,21 +60,21 @@ class ModelInfoDelete(LiteLLMBase):
class ModelInfo(LiteLLMBase): class ModelInfo(LiteLLMBase):
id: Optional[str] id: Optional[str]
mode: Optional[Literal['embedding', 'chat', 'completion']] mode: Optional[Literal["embedding", "chat", "completion"]]
input_cost_per_token: Optional[float] = 0.0 input_cost_per_token: Optional[float] = 0.0
output_cost_per_token: Optional[float] = 0.0 output_cost_per_token: Optional[float] = 0.0
max_tokens: Optional[int] = 2048 # assume 2048 if not set max_tokens: Optional[int] = 2048 # assume 2048 if not set
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
# we look up the base model in model_prices_and_context_window.json # we look up the base model in model_prices_and_context_window.json
base_model: Optional[Literal base_model: Optional[
[ Literal[
'gpt-4-1106-preview', "gpt-4-1106-preview",
'gpt-4-32k', "gpt-4-32k",
'gpt-4', "gpt-4",
'gpt-3.5-turbo-16k', "gpt-3.5-turbo-16k",
'gpt-3.5-turbo', "gpt-3.5-turbo",
'text-embedding-ada-002', "text-embedding-ada-002",
] ]
] ]
@ -79,7 +82,6 @@ class ModelInfo(LiteLLMBase):
extra = Extra.allow # Allow extra fields extra = Extra.allow # Allow extra fields
protected_namespaces = () protected_namespaces = ()
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("id") is None: if values.get("id") is None:
@ -97,7 +99,6 @@ class ModelInfo(LiteLLMBase):
return values return values
class ModelParams(LiteLLMBase): class ModelParams(LiteLLMBase):
model_name: str model_name: str
litellm_params: dict litellm_params: dict
@ -112,6 +113,7 @@ class ModelParams(LiteLLMBase):
values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
return values return values
class GenerateKeyRequest(LiteLLMBase): class GenerateKeyRequest(LiteLLMBase):
duration: Optional[str] = "1h" duration: Optional[str] = "1h"
models: Optional[list] = [] models: Optional[list] = []
@ -122,6 +124,7 @@ class GenerateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {} metadata: Optional[dict] = {}
class UpdateKeyRequest(LiteLLMBase): class UpdateKeyRequest(LiteLLMBase):
key: str key: str
duration: Optional[str] = None duration: Optional[str] = None
@ -133,10 +136,12 @@ class UpdateKeyRequest(LiteLLMBase):
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {} metadata: Optional[dict] = {}
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
""" """
Return the row in the db Return the row in the db
""" """
api_key: Optional[str] = None api_key: Optional[str] = None
models: list = [] models: list = []
aliases: dict = {} aliases: dict = {}
@ -147,45 +152,84 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api k
duration: str = "1h" duration: str = "1h"
metadata: dict = {} metadata: dict = {}
class GenerateKeyResponse(LiteLLMBase): class GenerateKeyResponse(LiteLLMBase):
key: str key: str
expires: Optional[datetime] expires: Optional[datetime]
user_id: str user_id: str
class _DeleteKeyObject(LiteLLMBase): class _DeleteKeyObject(LiteLLMBase):
key: str key: str
class DeleteKeyRequest(LiteLLMBase): class DeleteKeyRequest(LiteLLMBase):
keys: List[_DeleteKeyObject] keys: List[_DeleteKeyObject]
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class NewUserResponse(GenerateKeyResponse): class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """
Documents all the fields supported by `general_settings` in config.yaml Documents all the fields supported by `general_settings` in config.yaml
""" """
completion_model: Optional[str] = Field(None, description="proxy level default model for all chat completion calls")
use_azure_key_vault: Optional[bool] = Field(None, description="load keys from azure key vault") completion_model: Optional[str] = Field(
master_key: Optional[str] = Field(None, description="require a key for all calls to proxy") None, description="proxy level default model for all chat completion calls"
database_url: Optional[str] = Field(None, description="connect to a postgres db - needed for generating temporary keys + tracking spend / key") )
otel: Optional[bool] = Field(None, description="[BETA] OpenTelemetry support - this might change, use with caution.") use_azure_key_vault: Optional[bool] = Field(
custom_auth: Optional[str] = Field(None, description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth") None, description="load keys from azure key vault"
max_parallel_requests: Optional[int] = Field(None, description="maximum parallel requests for each api key") )
infer_model_from_keys: Optional[bool] = Field(None, description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)") master_key: Optional[str] = Field(
background_health_checks: Optional[bool] = Field(None, description="run health checks in background") None, description="require a key for all calls to proxy"
health_check_interval: int = Field(300, description="background health check interval in seconds") )
database_url: Optional[str] = Field(
None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
otel: Optional[bool] = Field(
None,
description="[BETA] OpenTelemetry support - this might change, use with caution.",
)
custom_auth: Optional[str] = Field(
None,
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
)
max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key"
)
infer_model_from_keys: Optional[bool] = Field(
None,
description="for `/models` endpoint, infers available model based on environment keys (e.g. OPENAI_API_KEY)",
)
background_health_checks: Optional[bool] = Field(
None, description="run health checks in background"
)
health_check_interval: int = Field(
300, description="background health check interval in seconds"
)
class ConfigYAML(LiteLLMBase): class ConfigYAML(LiteLLMBase):
""" """
Documents all the fields supported by the config.yaml Documents all the fields supported by the config.yaml
""" """
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache") model_list: Optional[List[ModelParams]] = Field(
None,
description="List of supported models on the server, with model-specific configs",
)
litellm_settings: Optional[dict] = Field(
None,
description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache",
)
general_settings: Optional[ConfigGeneralSettings] = None general_settings: Optional[ConfigGeneralSettings] = None
class Config: class Config:
protected_namespaces = () protected_namespaces = ()
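These pydantic models document the proxy's config.yaml. A hedged sketch of validating config data against them in Python; the field values are placeholders chosen for illustration, and it assumes the models behave as the root_validator usage above suggests (model_info is filled in automatically when omitted):

settings = ConfigGeneralSettings(
    master_key="sk-1234",  # placeholder proxy master key
    max_parallel_requests=10,
    background_health_checks=True,
)
config = ConfigYAML(
    model_list=[
        ModelParams(
            model_name="gpt-3.5-turbo",
            litellm_params={"model": "gpt-3.5-turbo"},
        )
    ],
    general_settings=settings,
)
print(config.general_settings.health_check_interval)  # falls back to the 300s default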


@ -4,6 +4,8 @@ from dotenv import load_dotenv
import os import os
load_dotenv() load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try: try:
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234" modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"


@ -10,12 +10,14 @@ from litellm.integrations.custom_logger import CustomLogger
import litellm import litellm
import inspect import inspect
# This file includes the custom callbacks for LiteLLM Proxy # This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml # Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
class MyCustomHandler(CustomLogger): class MyCustomHandler(CustomLogger):
def __init__(self): def __init__(self):
blue_color_code = "\033[94m" blue_color_code = "\033[94m"
@ -23,7 +25,11 @@ class MyCustomHandler(CustomLogger):
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger") print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
try: try:
print_verbose(f"Logger Initialized with following methods:") print_verbose(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print_verbose the methods # Pretty print_verbose the methods
for method in methods: for method in methods:
@ -32,7 +38,6 @@ class MyCustomHandler(CustomLogger):
except: except:
pass pass
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
print_verbose(f"Pre-API Call") print_verbose(f"Pre-API Call")
@ -45,7 +50,6 @@ class MyCustomHandler(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!") print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!") print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj) response_cost = litellm.completion_cost(completion_response=response_obj)
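For context, a handler like MyCustomHandler is registered on litellm before requests are made so these hooks fire. A hedged sketch, assuming litellm exposes a module-level callbacks list for CustomLogger instances as its custom-callback integration does:

# illustrative registration; variable name is arbitrary
proxy_handler_instance = MyCustomHandler()
litellm.callbacks = [proxy_handler_instance]
# subsequent litellm.completion()/acompletion() calls now trigger the hooks above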


@ -12,25 +12,16 @@ from litellm._logging import print_verbose
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ILLEGAL_DISPLAY_PARAMS = [ ILLEGAL_DISPLAY_PARAMS = ["messages", "api_key"]
"messages",
"api_key"
]
def _get_random_llm_message(): def _get_random_llm_message():
""" """
Get a random message from the LLM. Get a random message from the LLM.
""" """
messages = [ messages = ["Hey how's it going?", "What's 1 + 1?"]
"Hey how's it going?",
"What's 1 + 1?"
]
return [{"role": "user", "content": random.choice(messages)}]
return [
{"role": "user", "content": random.choice(messages)}
]
def _clean_litellm_params(litellm_params: dict): def _clean_litellm_params(litellm_params: dict):
@ -44,13 +35,16 @@ async def _perform_health_check(model_list: list):
""" """
Perform a health check for each model in the list. Perform a health check for each model in the list.
""" """
async def _check_img_gen_model(model_params: dict): async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None) model_params.pop("messages", None)
model_params["prompt"] = "test from litellm" model_params["prompt"] = "test from litellm"
try: try:
await litellm.aimage_generation(**model_params) await litellm.aimage_generation(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
@ -60,16 +54,19 @@ async def _perform_health_check(model_list: list):
try: try:
await litellm.aembedding(**model_params) await litellm.aembedding(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
async def _check_model(model_params: dict): async def _check_model(model_params: dict):
try: try:
await litellm.acompletion(**model_params) await litellm.acompletion(**model_params)
except Exception as e: except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}") print_verbose(
f"Health check failed for model {model_params['model']}. Error: {e}"
)
return False return False
return True return True
@ -104,9 +101,9 @@ async def _perform_health_check(model_list: list):
return healthy_endpoints, unhealthy_endpoints return healthy_endpoints, unhealthy_endpoints
async def perform_health_check(
model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None
async def perform_health_check(model_list: list, model: Optional[str] = None, cli_model: Optional[str] = None): ):
""" """
Perform a health check on the system. Perform a health check on the system.
@ -115,7 +112,9 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
""" """
if not model_list: if not model_list:
if cli_model: if cli_model:
model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}] model_list = [
{"model_name": cli_model, "litellm_params": {"model": cli_model}}
]
else: else:
return [], [] return [], []
@ -125,5 +124,3 @@ async def perform_health_check(model_list: list, model: Optional[str] = None, cl
healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list) healthy_endpoints, unhealthy_endpoints = await _perform_health_check(model_list)
return healthy_endpoints, unhealthy_endpoints return healthy_endpoints, unhealthy_endpoints
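A hedged sketch of invoking the health check above outside the proxy. The model entry is a placeholder, but it follows the same model_list shape the code itself builds from cli_model; models that fail (for example due to missing keys) simply land in the unhealthy list:

import asyncio

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # placeholder entry
        "litellm_params": {"model": "gpt-3.5-turbo"},
    }
]

healthy, unhealthy = asyncio.run(perform_health_check(model_list))
print(f"healthy: {len(healthy)}, unhealthy: {len(unhealthy)}")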


@ -6,6 +6,7 @@ from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
import json, traceback import json, traceback
class MaxBudgetLimiter(CustomLogger): class MaxBudgetLimiter(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
@ -15,7 +16,13 @@ class MaxBudgetLimiter(CustomLogger):
if litellm.set_verbose is True: if litellm.set_verbose is True:
print(print_statement) # noqa print(print_statement) # noqa
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try: try:
self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook") self.print_verbose(f"Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"


@ -5,8 +5,10 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
class MaxParallelRequestsHandler(CustomLogger): class MaxParallelRequestsHandler(CustomLogger):
user_api_key_cache = None user_api_key_cache = None
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
pass pass
@ -15,8 +17,13 @@ class MaxParallelRequestsHandler(CustomLogger):
if litellm.set_verbose is True: if litellm.set_verbose is True:
print(print_statement) # noqa print(print_statement) # noqa
async def async_pre_call_hook(
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str): self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests max_parallel_requests = user_api_key_dict.max_parallel_requests
@ -39,8 +46,9 @@ class MaxParallelRequestsHandler(CustomLogger):
# Increase count for this token # Increase count for this token
cache.set_cache(request_count_api_key, int(current) + 1) cache.set_cache(request_count_api_key, int(current) + 1)
else: else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.") raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
@ -54,17 +62,23 @@ class MaxParallelRequestsHandler(CustomLogger):
request_count_api_key = f"{user_api_key}_request_count" request_count_api_key = f"{user_api_key}_request_count"
# check if it has collected an entire stream response # check if it has collected an entire stream response
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}") self.print_verbose(
f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
)
if "complete_streaming_response" in kwargs or kwargs["stream"] != True: if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
# Decrease count for this token # Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1 new_val = current - 1
self.print_verbose(f"updated_value in success call: {new_val}") self.print_verbose(f"updated_value in success call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val) self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e: except Exception as e:
self.print_verbose(e) # noqa self.print_verbose(e) # noqa
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception): async def async_log_failure_call(
self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
):
try: try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook") self.print_verbose(f"Inside Max Parallel Request Failure Hook")
api_key = user_api_key_dict.api_key api_key = user_api_key_dict.api_key
@ -75,14 +89,18 @@ class MaxParallelRequestsHandler(CustomLogger):
return return
## decrement call count if call failed ## decrement call count if call failed
if (hasattr(original_exception, "status_code") if (
hasattr(original_exception, "status_code")
and original_exception.status_code == 429 and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)): and "Max parallel request limit reached" in str(original_exception)
):
pass # ignore failed calls due to max limit being reached pass # ignore failed calls due to max limit being reached
else: else:
request_count_api_key = f"{api_key}_request_count" request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token # Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1 current = (
self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
)
new_val = current - 1 new_val = current - 1
self.print_verbose(f"updated_value in failure call: {new_val}") self.print_verbose(f"updated_value in failure call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val) self.user_api_key_cache.set_cache(request_count_api_key, new_val)
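The handler above rate-limits by keeping a per-key "{api_key}_request_count" counter in the cache: incremented in the pre-call hook, rejected with a 429 once it reaches max_parallel_requests, and decremented again on success or failure. A minimal in-memory sketch of that counter logic, using a plain dict instead of DualCache (class and method names are illustrative):

class ParallelLimiter:
    def __init__(self, max_parallel_requests: int):
        self.max_parallel_requests = max_parallel_requests
        self.counts: dict = {}

    def acquire(self, api_key: str) -> None:
        key = f"{api_key}_request_count"
        current = self.counts.get(key, 0)
        if current >= self.max_parallel_requests:
            raise RuntimeError("Max parallel request limit reached.")
        self.counts[key] = current + 1

    def release(self, api_key: str) -> None:
        key = f"{api_key}_request_count"
        self.counts[key] = max(self.counts.get(key, 1) - 1, 0)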


@ -6,34 +6,42 @@ from datetime import datetime
import importlib import importlib
from dotenv import load_dotenv from dotenv import load_dotenv
import operator import operator
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
config_filename = "litellm.secrets" config_filename = "litellm.secrets"
# Using appdirs to determine user-specific config path # Using appdirs to determine user-specific config path
config_dir = appdirs.user_config_dir("litellm") config_dir = appdirs.user_config_dir("litellm")
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)) user_config_path = os.getenv(
"LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
)
load_dotenv() load_dotenv()
from importlib import resources from importlib import resources
import shutil import shutil
telemetry = None telemetry = None
def run_ollama_serve(): def run_ollama_serve():
try: try:
command = ['ollama', 'serve'] command = ["ollama", "serve"]
with open(os.devnull, 'w') as devnull: with open(os.devnull, "w") as devnull:
process = subprocess.Popen(command, stdout=devnull, stderr=devnull) process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e: except Exception as e:
print(f""" print(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""") # noqa """
) # noqa
def clone_subfolder(repo_url, subfolder, destination): def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo # Clone the full repo
repo_name = repo_url.split('/')[-1] repo_name = repo_url.split("/")[-1]
repo_master = os.path.join(destination, "repo_master") repo_master = os.path.join(destination, "repo_master")
subprocess.run(['git', 'clone', repo_url, repo_master]) subprocess.run(["git", "clone", repo_url, repo_master])
# Move into the subfolder # Move into the subfolder
subfolder_path = os.path.join(repo_master, subfolder) subfolder_path = os.path.join(repo_master, subfolder)
@ -48,43 +56,152 @@ def clone_subfolder(repo_url, subfolder, destination):
shutil.copytree(source, dest_path) shutil.copytree(source, dest_path)
# Remove cloned repo folder # Remove cloned repo folder
subprocess.run(['rm', '-rf', os.path.join(destination, "repo_master")]) subprocess.run(["rm", "-rf", os.path.join(destination, "repo_master")])
feature_telemetry(feature="create-proxy") feature_telemetry(feature="create-proxy")
def is_port_in_use(port): def is_port_in_use(port):
import socket import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0 return s.connect_ex(("localhost", port)) == 0
@click.command() @click.command()
@click.option('--host', default='0.0.0.0', help='Host for the server to listen on.') @click.option("--host", default="0.0.0.0", help="Host for the server to listen on.")
@click.option('--port', default=8000, help='Port to bind the server to.') @click.option("--port", default=8000, help="Port to bind the server to.")
@click.option('--num_workers', default=1, help='Number of uvicorn workers to spin up') @click.option("--num_workers", default=1, help="Number of uvicorn workers to spin up")
@click.option('--api_base', default=None, help='API base URL.') @click.option("--api_base", default=None, help="API base URL.")
@click.option('--api_version', default="2023-07-01-preview", help='For azure - pass in the api version.') @click.option(
@click.option('--model', '-m', default=None, help='The model name to pass to litellm expects') "--api_version",
@click.option('--alias', default=None, help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")') default="2023-07-01-preview",
@click.option('--add_key', default=None, help='The model name to pass to litellm expects') help="For azure - pass in the api version.",
@click.option('--headers', default=None, help='headers for the API call') )
@click.option('--save', is_flag=True, type=bool, help='Save the model-specific config') @click.option(
@click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input') "--model", "-m", default=None, help="The model name to pass to litellm expects"
@click.option('--use_queue', default=False, is_flag=True, type=bool, help='To use celery workers for async endpoints') )
@click.option('--temperature', default=None, type=float, help='Set temperature for the model') @click.option(
@click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model') "--alias",
@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls') default=None,
@click.option('--drop_params', is_flag=True, help='Drop any unmapped params') help='The alias for the model - use this to give a litellm model name (e.g. "huggingface/codellama/CodeLlama-7b-Instruct-hf") a more user-friendly name ("codellama")',
@click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt') )
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`') @click.option(
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`') "--add_key", default=None, help="The model name to pass to litellm expects"
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`') )
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version') @click.option("--headers", default=None, help="headers for the API call")
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.') @click.option("--save", is_flag=True, type=bool, help="Save the model-specific config")
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml') @click.option(
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to') "--debug", default=False, is_flag=True, type=bool, help="To debug the input"
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response') )
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with') @click.option(
@click.option('--local', is_flag=True, default=False, help='for local debugging') "--use_queue",
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version): default=False,
is_flag=True,
type=bool,
help="To use celery workers for async endpoints",
)
@click.option(
"--temperature", default=None, type=float, help="Set temperature for the model"
)
@click.option(
"--max_tokens", default=None, type=int, help="Set max tokens for the model"
)
@click.option(
"--request_timeout",
default=600,
type=int,
help="Set timeout in seconds for completion calls",
)
@click.option("--drop_params", is_flag=True, help="Drop any unmapped params")
@click.option(
"--add_function_to_prompt",
is_flag=True,
help="If function passed but unsupported, pass it as prompt",
)
@click.option(
"--config",
"-c",
default=None,
help="Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`",
)
@click.option(
"--max_budget",
default=None,
type=float,
help="Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`",
)
@click.option(
"--telemetry",
default=True,
type=bool,
help="Helps us know if people are using this feature. Turn this off by doing `--telemetry False`",
)
@click.option(
"--version",
"-v",
default=False,
is_flag=True,
type=bool,
help="Print LiteLLM version",
)
@click.option(
"--logs",
flag_value=False,
type=int,
help='Gets the "n" most recent logs. By default gets most recent log.',
)
@click.option(
"--health",
flag_value=True,
help="Make a chat/completions request to all llms in config.yaml",
)
@click.option(
"--test",
flag_value=True,
help="proxy chat completions url to make a test request to",
)
@click.option(
"--test_async",
default=False,
is_flag=True,
help="Calls async endpoints /queue/requests and /queue/response",
)
@click.option(
"--num_requests",
default=10,
type=int,
help="Number of requests to hit async endpoint with",
)
@click.option("--local", is_flag=True, default=False, help="for local debugging")
def run_server(
host,
port,
api_base,
api_version,
model,
alias,
add_key,
headers,
save,
debug,
temperature,
max_tokens,
request_timeout,
drop_params,
add_function_to_prompt,
config,
max_budget,
telemetry,
logs,
test,
local,
num_workers,
test_async,
num_requests,
use_queue,
health,
version,
):
global feature_telemetry global feature_telemetry
args = locals() args = locals()
if local: if local:
@ -99,17 +216,23 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if logs == 0: # default to 1 if logs == 0: # default to 1
logs = 1 logs = 1
try: try:
with open('api_log.json') as f: with open("api_log.json") as f:
data = json.load(f) data = json.load(f)
# convert keys to datetime objects # convert keys to datetime objects
log_times = {datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()} log_times = {
datetime.strptime(k, "%Y%m%d%H%M%S%f"): v for k, v in data.items()
}
# sort by timestamp # sort by timestamp
sorted_times = sorted(log_times.items(), key=operator.itemgetter(0), reverse=True) sorted_times = sorted(
log_times.items(), key=operator.itemgetter(0), reverse=True
)
# get n recent logs # get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]} recent_logs = {
k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]
}
print(json.dumps(recent_logs, indent=4)) # noqa print(json.dumps(recent_logs, indent=4)) # noqa
except: except:
@ -117,18 +240,21 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return return
if version == True: if version == True:
pkg_version = importlib.metadata.version("litellm") pkg_version = importlib.metadata.version("litellm")
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n') click.echo(f"\nLiteLLM: Current Version = {pkg_version}\n")
return return
if model and "ollama" in model and api_base is None: if model and "ollama" in model and api_base is None:
run_ollama_serve() run_ollama_serve()
if test_async is True: if test_async is True:
import requests, concurrent, time import requests, concurrent, time
api_base = f"http://{host}:{port}" api_base = f"http://{host}:{port}"
def _make_openai_completion(): def _make_openai_completion():
data = { data = {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Write a short poem about the moon"}] "messages": [
{"role": "user", "content": "Write a short poem about the moon"}
],
} }
response = requests.post("http://0.0.0.0:8000/queue/request", json=data) response = requests.post("http://0.0.0.0:8000/queue/request", json=data)
@ -146,7 +272,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished": if status == "finished":
llm_response = polling_response["result"] llm_response = polling_response["result"]
break break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa print(
f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}"
) # noqa
time.sleep(0.5) time.sleep(0.5)
except Exception as e: except Exception as e:
print("got exception in polling", e) print("got exception in polling", e)
@ -159,7 +287,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
futures = [] futures = []
start_time = time.time() start_time = time.time()
# Make concurrent calls # Make concurrent calls
with concurrent.futures.ThreadPoolExecutor(max_workers=concurrent_calls) as executor: with concurrent.futures.ThreadPoolExecutor(
max_workers=concurrent_calls
) as executor:
for _ in range(concurrent_calls): for _ in range(concurrent_calls):
futures.append(executor.submit(_make_openai_completion)) futures.append(executor.submit(_make_openai_completion))
@ -185,58 +315,86 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
return return
if health != False: if health != False:
import requests import requests
print("\nLiteLLM: Health Testing models in config") print("\nLiteLLM: Health Testing models in config")
response = requests.get(url=f"http://{host}:{port}/health") response = requests.get(url=f"http://{host}:{port}/health")
print(json.dumps(response.json(), indent=4)) print(json.dumps(response.json(), indent=4))
return return
if test != False: if test != False:
click.echo('\nLiteLLM: Making a test ChatCompletions request to your proxy') click.echo("\nLiteLLM: Making a test ChatCompletions request to your proxy")
import openai import openai
if test == True: # flag value set if test == True: # flag value set
api_base = f"http://{host}:{port}" api_base = f"http://{host}:{port}"
else: else:
api_base = test api_base = test
client = openai.OpenAI( client = openai.OpenAI(api_key="My API Key", base_url=api_base)
api_key="My API Key",
base_url=api_base
)
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{ {
"role": "user", "role": "user",
"content": "this is a test request, write a short poem" "content": "this is a test request, write a short poem",
} }
], max_tokens=256) ],
click.echo(f'\nLiteLLM: response from proxy {response}') max_tokens=256,
)
click.echo(f"\nLiteLLM: response from proxy {response}")
print("\n Making streaming request to proxy") print("\n Making streaming request to proxy")
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [ response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{ {
"role": "user", "role": "user",
"content": "this is a test request, write a short poem" "content": "this is a test request, write a short poem",
} }
], ],
stream=True, stream=True,
) )
for chunk in response: for chunk in response:
click.echo(f'LiteLLM: streaming response from proxy {chunk}') click.echo(f"LiteLLM: streaming response from proxy {chunk}")
print("\n making completion request to proxy") print("\n making completion request to proxy")
response = client.completions.create(model="gpt-3.5-turbo", prompt='this is a test request, write a short poem') response = client.completions.create(
model="gpt-3.5-turbo", prompt="this is a test request, write a short poem"
)
print(response) print(response)
return return
else: else:
if headers: if headers:
headers = json.loads(headers) headers = json.loads(headers)
save_worker_config(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save, config=config, use_queue=use_queue) save_worker_config(
model=model,
alias=alias,
api_base=api_base,
api_version=api_version,
debug=debug,
temperature=temperature,
max_tokens=max_tokens,
request_timeout=request_timeout,
max_budget=max_budget,
telemetry=telemetry,
drop_params=drop_params,
add_function_to_prompt=add_function_to_prompt,
headers=headers,
save=save,
config=config,
use_queue=use_queue,
)
try: try:
import uvicorn import uvicorn
except: except:
raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`") raise ImportError(
"Uvicorn needs to be imported. Run - `pip install uvicorn`"
)
if port == 8000 and is_port_in_use(port): if port == 8000 and is_port_in_use(port):
port = random.randint(1024, 49152) port = random.randint(1024, 49152)
uvicorn.run("litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers) uvicorn.run(
"litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
)
if __name__ == "__main__": if __name__ == "__main__":

File diff suppressed because it is too large

View file

@ -1,71 +1,77 @@
from dotenv import load_dotenv # from dotenv import load_dotenv
load_dotenv()
import json, subprocess
import psutil # Import the psutil library
import atexit
try:
### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
from celery import Celery
import redis
except:
import sys
subprocess.check_call( # load_dotenv()
[ # import json, subprocess
sys.executable, # import psutil # Import the psutil library
"-m", # import atexit
"pip",
"install",
"redis",
"celery"
]
)
import time # try:
import sys, os # ### OPTIONAL DEPENDENCIES ### - pip install redis and celery only when a user opts into using the async endpoints which require both
sys.path.insert( # from celery import Celery
0, os.path.abspath("../../..") # import redis
) # Adds the parent directory to the system path - for litellm local dev # except:
import litellm # import sys
# Redis connection setup # subprocess.check_call([sys.executable, "-m", "pip", "install", "redis", "celery"])
pool = redis.ConnectionPool(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"), db=0, max_connections=5)
redis_client = redis.Redis(connection_pool=pool)
# Celery setup # import time
celery_app = Celery('tasks', broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}", backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}") # import sys, os
celery_app.conf.update(
broker_pool_limit = None, # sys.path.insert(
broker_transport_options = {'connection_pool': pool}, # 0, os.path.abspath("../../..")
result_backend_transport_options = {'connection_pool': pool}, # ) # Adds the parent directory to the system path - for litellm local dev
) # import litellm
# # Redis connection setup
# pool = redis.ConnectionPool(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# db=0,
# max_connections=5,
# )
# redis_client = redis.Redis(connection_pool=pool)
# # Celery setup
# celery_app = Celery(
# "tasks",
# broker=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# backend=f"redis://default:{os.getenv('REDIS_PASSWORD')}@{os.getenv('REDIS_HOST')}:{os.getenv('REDIS_PORT')}",
# )
# celery_app.conf.update(
# broker_pool_limit=None,
# broker_transport_options={"connection_pool": pool},
# result_backend_transport_options={"connection_pool": pool},
# )
# Celery task # # Celery task
@celery_app.task(name='process_job', max_retries=3) # @celery_app.task(name="process_job", max_retries=3)
def process_job(*args, **kwargs): # def process_job(*args, **kwargs):
try: # try:
llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore # llm_router: litellm.Router = litellm.Router(model_list=kwargs.pop("llm_model_list")) # type: ignore
response = llm_router.completion(*args, **kwargs) # type: ignore # response = llm_router.completion(*args, **kwargs) # type: ignore
if isinstance(response, litellm.ModelResponse): # if isinstance(response, litellm.ModelResponse):
response = response.model_dump_json() # response = response.model_dump_json()
return json.loads(response) # return json.loads(response)
return str(response) # return str(response)
except Exception as e: # except Exception as e:
raise e # raise e
# Ensure Celery workers are terminated when the script exits
def cleanup():
try:
# Get a list of all running processes
for process in psutil.process_iter(attrs=['pid', 'name']):
# Check if the process is a Celery worker process
if process.info['name'] == 'celery':
print(f"Terminating Celery worker with PID {process.info['pid']}")
# Terminate the Celery worker process
psutil.Process(process.info['pid']).terminate()
except Exception as e:
print(f"Error during cleanup: {e}")
# Register the cleanup function to run when the script exits # # Ensure Celery workers are terminated when the script exits
atexit.register(cleanup) # def cleanup():
# try:
# # Get a list of all running processes
# for process in psutil.process_iter(attrs=["pid", "name"]):
# # Check if the process is a Celery worker process
# if process.info["name"] == "celery":
# print(f"Terminating Celery worker with PID {process.info['pid']}")
# # Terminate the Celery worker process
# psutil.Process(process.info["pid"]).terminate()
# except Exception as e:
# print(f"Error during cleanup: {e}")
# # Register the cleanup function to run when the script exits
# atexit.register(cleanup)

View file

@ -1,12 +1,15 @@
import os import os
from multiprocessing import Process from multiprocessing import Process
def run_worker(cwd): def run_worker(cwd):
os.chdir(cwd) os.chdir(cwd)
os.system("celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info") os.system(
"celery -A celery_app.celery_app worker --concurrency=120 --loglevel=info"
)
def start_worker(cwd): def start_worker(cwd):
cwd += "/queue" cwd += "/queue"
worker_process = Process(target=run_worker, args=(cwd,)) worker_process = Process(target=run_worker, args=(cwd,))
worker_process.start() worker_process.start()

View file

@ -1,26 +1,34 @@
import sys, os # import sys, os
from dotenv import load_dotenv # from dotenv import load_dotenv
load_dotenv()
# Add the path to the local folder to sys.path
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path - for litellm local dev
def start_rq_worker(): # load_dotenv()
from rq import Worker, Queue, Connection # # Add the path to the local folder to sys.path
from redis import Redis # sys.path.insert(
# Set up RQ connection # 0, os.path.abspath("../../..")
redis_conn = Redis(host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD")) # ) # Adds the parent directory to the system path - for litellm local dev
print(redis_conn.ping()) # Should print True if connected successfully
# Create a worker and add the queue
try:
queue = Queue(connection=redis_conn)
worker = Worker([queue], connection=redis_conn)
except Exception as e:
print(f"Error setting up worker: {e}")
exit()
with Connection(redis_conn):
worker.work()
start_rq_worker() # def start_rq_worker():
# from rq import Worker, Queue, Connection
# from redis import Redis
# # Set up RQ connection
# redis_conn = Redis(
# host=os.getenv("REDIS_HOST"),
# port=os.getenv("REDIS_PORT"),
# password=os.getenv("REDIS_PASSWORD"),
# )
# print(redis_conn.ping()) # Should print True if connected successfully
# # Create a worker and add the queue
# try:
# queue = Queue(connection=redis_conn)
# worker = Worker([queue], connection=redis_conn)
# except Exception as e:
# print(f"Error setting up worker: {e}")
# exit()
# with Connection(redis_conn):
# worker.work()
# start_rq_worker()

View file

@ -4,10 +4,7 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
api_key="test",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
@ -15,9 +12,10 @@ async def litellm_completion():
try: try:
response = await litellm_client.chat.completions.create( response = await litellm_client.chat.completions.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"*180}], # this is about 4k tokens per request messages=[
{"role": "user", "content": f"This is a test: {uuid.uuid4()}" * 180}
], # this is about 4k tokens per request
) )
print(response)
return response return response
except Exception as e: except Exception as e:
@ -27,7 +25,6 @@ async def litellm_completion():
pass pass
async def main(): async def main():
start = time.time() start = time.time()
n = 60 # Send 60 concurrent requests, each with 4k tokens = 240k Tokens n = 60 # Send 60 concurrent requests, each with 4k tokens = 240k Tokens
@ -45,6 +42,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()

View file

@ -4,10 +4,7 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")
api_key="sk-1234",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
@ -26,7 +23,6 @@ async def litellm_completion():
pass pass
async def main(): async def main():
start = time.time() start = time.time()
n = 1000 # Number of concurrent tasks n = 1000 # Number of concurrent tasks
@ -44,6 +40,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()

View file

@ -14,8 +14,8 @@ import pytest
import litellm import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100 question = "embed this very long text" * 100
@ -35,7 +35,10 @@ def make_openai_completion(question):
try: try:
start_time = time.time() start_time = time.time()
import openai import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY']) #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"]
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create( response = client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=[question], input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# ) # )
return None return None
start_time = time.time() start_time = time.time()
# Number of concurrent calls (you can adjust this) # Number of concurrent calls (you can adjust this)
concurrent_calls = 500 concurrent_calls = 500

View file

@ -4,11 +4,7 @@ import uuid
import traceback import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(api_key="test", base_url="http://0.0.0.0:8000")
api_key="test",
base_url="http://0.0.0.0:8000"
)
async def litellm_completion(): async def litellm_completion():
@ -17,11 +13,11 @@ async def litellm_completion():
print("starting embedding calls") print("starting embedding calls")
response = await litellm_client.embeddings.create( response = await litellm_client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input = [ input=[
"hello who are you" * 2000, "hello who are you" * 2000,
"hello who are you tomorrow 1234" * 1000, "hello who are you tomorrow 1234" * 1000,
"hello who are you tomorrow 1234" * 1000 "hello who are you tomorrow 1234" * 1000,
] ],
) )
print(response) print(response)
return response return response
@ -33,7 +29,6 @@ async def litellm_completion():
pass pass
async def main(): async def main():
start = time.time() start = time.time()
n = 100 # Number of concurrent tasks n = 100 # Number of concurrent tasks
@ -51,6 +46,7 @@ async def main():
print(n, time.time() - start, len(successful_completions)) print(n, time.time() - start, len(successful_completions))
if __name__ == "__main__": if __name__ == "__main__":
# Blank out contents of error_log.txt # Blank out contents of error_log.txt
open("error_log.txt", "w").close() open("error_log.txt", "w").close()

View file

@ -14,8 +14,8 @@ import pytest
import litellm import litellm
litellm.set_verbose=False
litellm.set_verbose = False
question = "embed this very long text" * 100 question = "embed this very long text" * 100
@ -35,7 +35,10 @@ def make_openai_completion(question):
try: try:
start_time = time.time() start_time = time.time()
import openai import openai
client = openai.OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url="http://0.0.0.0:8000") #base_url="http://0.0.0.0:8000",
client = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"], base_url="http://0.0.0.0:8000"
) # base_url="http://0.0.0.0:8000",
response = client.embeddings.create( response = client.embeddings.create(
model="text-embedding-ada-002", model="text-embedding-ada-002",
input=[question], input=[question],
@ -58,6 +61,7 @@ def make_openai_completion(question):
# ) # )
return None return None
start_time = time.time() start_time = time.time()
# Number of concurrent calls (you can adjust this) # Number of concurrent calls (you can adjust this)
concurrent_calls = 500 concurrent_calls = 500
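The embedding load tests above reduce to one client call pointed at the proxy; a minimal synchronous sketch, assuming a proxy listening locally on port 8000 and using a placeholder key:

import openai

def embed_through_proxy(text: str, base_url: str = "http://0.0.0.0:8000") -> list:
    # Point the standard OpenAI client at the LiteLLM proxy and request an embedding.
    client = openai.OpenAI(api_key="sk-placeholder", base_url=base_url)
    response = client.embeddings.create(
        model="text-embedding-ada-002",
        input=[text],
    )
    return response.data[0].embedding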

View file

@ -2,6 +2,7 @@ import requests
import time import time
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -17,32 +18,30 @@ config = {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'], "api_key": os.environ["OPENAI_API_KEY"],
} },
}, },
{ {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.environ['AZURE_API_KEY'], "api_key": os.environ["AZURE_API_KEY"],
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-07-01-preview" "api_version": "2023-07-01-preview",
} },
} },
] ]
} }
print("STARTING LOAD TEST Q") print("STARTING LOAD TEST Q")
print(os.environ['AZURE_API_KEY']) print(os.environ["AZURE_API_KEY"])
response = requests.post( response = requests.post(
url=f"{base_url}/key/generate", url=f"{base_url}/key/generate",
json={ json={
"config": config, "config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key "duration": "30d", # default to 30d, set it to 30m if you want a temp key
}, },
headers={ headers={"Authorization": "Bearer sk-hosted-litellm"},
"Authorization": "Bearer sk-hosted-litellm"
}
) )
print("\nresponse from generating key", response.text) print("\nresponse from generating key", response.text)
@ -56,19 +55,18 @@ print("\ngenerated key for proxy", generated_key)
import concurrent.futures import concurrent.futures
def create_job_and_poll(request_num): def create_job_and_poll(request_num):
print(f"Creating a job on the proxy for request {request_num}") print(f"Creating a job on the proxy for request {request_num}")
job_response = requests.post( job_response = requests.post(
url=f"{base_url}/queue/request", url=f"{base_url}/queue/request",
json={ json={
'model': 'gpt-3.5-turbo', "model": "gpt-3.5-turbo",
'messages': [ "messages": [
{'role': 'system', 'content': 'write a short poem'}, {"role": "system", "content": "write a short poem"},
], ],
}, },
headers={ headers={"Authorization": f"Bearer {generated_key}"},
"Authorization": f"Bearer {generated_key}"
}
) )
print(job_response.status_code) print(job_response.status_code)
print(job_response.text) print(job_response.text)
@ -84,12 +82,12 @@ def create_job_and_poll(request_num):
try: try:
print(f"\nPolling URL for request {request_num}", polling_url) print(f"\nPolling URL for request {request_num}", polling_url)
polling_response = requests.get( polling_response = requests.get(
url=polling_url, url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
headers={ )
"Authorization": f"Bearer {generated_key}" print(
} f"\nResponse from polling url for request {request_num}",
polling_response.text,
) )
print(f"\nResponse from polling url for request {request_num}", polling_response.text)
polling_response = polling_response.json() polling_response = polling_response.json()
status = polling_response.get("status", None) status = polling_response.get("status", None)
if status == "finished": if status == "finished":
@ -109,6 +107,7 @@ def create_job_and_poll(request_num):
except Exception as e: except Exception as e:
print("got exception when polling", e) print("got exception when polling", e)
# Number of requests # Number of requests
num_requests = 100 num_requests = 100

View file

@ -26,4 +26,3 @@
# import asyncio # import asyncio
# asyncio.run(test_async_completion()) # asyncio.run(test_async_completion())

View file

@ -34,7 +34,3 @@
# response = claude_chat(messages) # response = claude_chat(messages)
# print(response) # print(response)

View file

@ -2,6 +2,7 @@ import requests
import time import time
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -17,8 +18,8 @@ config = {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-3.5-turbo",
"litellm_params": { "litellm_params": {
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"api_key": os.environ['OPENAI_API_KEY'], "api_key": os.environ["OPENAI_API_KEY"],
} },
} }
] ]
} }
@ -27,11 +28,9 @@ response = requests.post(
url=f"{base_url}/key/generate", url=f"{base_url}/key/generate",
json={ json={
"config": config, "config": config,
"duration": "30d" # default to 30d, set it to 30m if you want a temp key "duration": "30d", # default to 30d, set it to 30m if you want a temp key
}, },
headers={ headers={"Authorization": "Bearer sk-hosted-litellm"},
"Authorization": "Bearer sk-hosted-litellm"
}
) )
print("\nresponse from generating key", response.text) print("\nresponse from generating key", response.text)
@ -45,14 +44,15 @@ print("Creating a job on the proxy")
job_response = requests.post( job_response = requests.post(
url=f"{base_url}/queue/request", url=f"{base_url}/queue/request",
json={ json={
'model': 'gpt-3.5-turbo', "model": "gpt-3.5-turbo",
'messages': [ "messages": [
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'}, {
"role": "system",
"content": f"You are a helpful assistant. What is your name",
},
], ],
}, },
headers={ headers={"Authorization": f"Bearer {generated_key}"},
"Authorization": f"Bearer {generated_key}"
}
) )
print(job_response.status_code) print(job_response.status_code)
print(job_response.text) print(job_response.text)
@ -68,10 +68,7 @@ while True:
try: try:
print("\nPolling URL", polling_url) print("\nPolling URL", polling_url)
polling_response = requests.get( polling_response = requests.get(
url=polling_url, url=polling_url, headers={"Authorization": f"Bearer {generated_key}"}
headers={
"Authorization": f"Bearer {generated_key}"
}
) )
print("\nResponse from polling url", polling_response.text) print("\nResponse from polling url", polling_response.text)
polling_response = polling_response.json() polling_response = polling_response.json()

View file

@ -8,9 +8,12 @@ from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException, status from fastapi import HTTPException, status
def print_verbose(print_statement): def print_verbose(print_statement):
if litellm.set_verbose: if litellm.set_verbose:
print(f"LiteLLM Proxy: {print_statement}") # noqa print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ### ### LOGGING ###
class ProxyLogging: class ProxyLogging:
""" """
@ -57,11 +60,14 @@ class ProxyLogging:
+ litellm.failure_callback + litellm.failure_callback
) )
) )
litellm.utils.set_callbacks( litellm.utils.set_callbacks(callback_list=callback_list)
callback_list=callback_list
)
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]): async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal["completion", "embeddings"],
):
""" """
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@ -71,12 +77,19 @@ class ProxyLogging:
""" """
try: try:
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__): if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type) callback.__class__
):
response = await callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None: if response is not None:
data = response data = response
print_verbose(f'final data being sent to {call_type} call: {data}') print_verbose(f"final data being sent to {call_type} call: {data}")
return data return data
except Exception as e: except Exception as e:
raise e raise e
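The dispatch above only calls callbacks that define async_pre_call_hook, so request guards can be plugged in without modifying the proxy. A minimal sketch of such a callback; the signature mirrors the call site in this hunk, while the rejection rule itself is invented for illustration:

from fastapi import HTTPException
from litellm.integrations.custom_logger import CustomLogger

class RejectOversizedPrompts(CustomLogger):
    # Illustrative pre-call hook: reject oversized completion requests,
    # otherwise return the (possibly modified) request body unchanged.
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        if call_type == "completion":
            messages = data.get("messages", [])
            if messages and len(str(messages[0].get("content", ""))) > 10_000:
                raise HTTPException(status_code=400, detail="prompt too long")
        return data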
@ -96,7 +109,9 @@ class ProxyLogging:
if litellm.utils.capture_exception: if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception) litellm.utils.capture_exception(error=original_exception)
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth): async def post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
""" """
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
@ -108,7 +123,10 @@ class ProxyLogging:
for callback in litellm.callbacks: for callback in litellm.callbacks:
try: try:
if isinstance(callback, CustomLogger): if isinstance(callback, CustomLogger):
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception) await callback.async_post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=original_exception,
)
except Exception as e: except Exception as e:
raise e raise e
return return
@ -121,9 +139,12 @@ def on_backoff(details):
# The 'tries' key in the details dictionary contains the number of completed tries # The 'tries' key in the details dictionary contains the number of completed tries
print_verbose(f"Backing off... this was attempt #{details['tries']}") print_verbose(f"Backing off... this was attempt #{details['tries']}")
class PrismaClient: class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") print_verbose(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
## init logging object ## init logging object
self.proxy_logging_obj = proxy_logging_obj self.proxy_logging_obj = proxy_logging_obj
@ -136,15 +157,16 @@ class PrismaClient:
os.chdir(dname) os.chdir(dname)
try: try:
subprocess.run(['prisma', 'generate']) subprocess.run(["prisma", "generate"])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss subprocess.run(
["prisma", "db", "push", "--accept-data-loss"]
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
# Now you can import the Prisma Client # Now you can import the Prisma Client
from prisma import Client # type: ignore from prisma import Client # type: ignore
self.db = Client() #Client to connect to Prisma db
self.db = Client() # Client to connect to Prisma db
def hash_token(self, token: str): def hash_token(self, token: str):
# Hash the string using SHA-256 # Hash the string using SHA-256
@ -167,7 +189,12 @@ class PrismaClient:
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def get_data(self, token: Optional[str]=None, expires: Optional[Any]=None, user_id: Optional[str]=None): async def get_data(
self,
token: Optional[str] = None,
expires: Optional[Any] = None,
user_id: Optional[str] = None,
):
try: try:
response = None response = None
if token is not None: if token is not None:
@ -176,9 +203,7 @@ class PrismaClient:
if token.startswith("sk-"): if token.startswith("sk-"):
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
response = await self.db.litellm_verificationtoken.find_unique( response = await self.db.litellm_verificationtoken.find_unique(
where={ where={"token": hashed_token}
"token": hashed_token
}
) )
if response: if response:
# Token exists, now check expiration. # Token exists, now check expiration.
@ -188,11 +213,17 @@ class PrismaClient:
return response return response
else: else:
# Token exists but is expired. # Token exists but is expired.
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="expired user key") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="expired user key",
)
return response return response
else: else:
# Token does not exist. # Token does not exist.
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid user key",
)
elif user_id is not None: elif user_id is not None:
response = await self.db.litellm_usertable.find_unique( # type: ignore response = await self.db.litellm_usertable.find_unique( # type: ignore
where={ where={
@ -201,7 +232,9 @@ class PrismaClient:
) )
return response return response
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@ -224,26 +257,26 @@ class PrismaClient:
max_budget = db_data.pop("max_budget", None) max_budget = db_data.pop("max_budget", None)
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={ where={
'token': hashed_token, "token": hashed_token,
}, },
data={ data={
"create": {**db_data}, #type: ignore "create": {**db_data}, # type: ignore
"update": {} # don't do anything if it already exists "update": {}, # don't do anything if it already exists
} },
) )
new_user_row = await self.db.litellm_usertable.upsert( new_user_row = await self.db.litellm_usertable.upsert(
where={ where={"user_id": data["user_id"]},
'user_id': data['user_id']
},
data={ data={
"create": {"user_id": data['user_id'], "max_budget": max_budget}, "create": {"user_id": data["user_id"], "max_budget": max_budget},
"update": {} # don't do anything if it already exists "update": {}, # don't do anything if it already exists
} },
) )
return new_verification_token return new_verification_token
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@ -254,7 +287,12 @@ class PrismaClient:
max_time=10, # maximum total time to retry for max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def update_data(self, token: Optional[str]=None, data: dict={}, user_id: Optional[str]=None): async def update_data(
self,
token: Optional[str] = None,
data: dict = {},
user_id: Optional[str] = None,
):
""" """
Update existing data Update existing data
""" """
@ -267,10 +305,8 @@ class PrismaClient:
token = self.hash_token(token=token) token = self.hash_token(token=token)
db_data["token"] = token db_data["token"] = token
response = await self.db.litellm_verificationtoken.update( response = await self.db.litellm_verificationtoken.update(
where={ where={"token": token}, # type: ignore
"token": token # type: ignore data={**db_data}, # type: ignore
},
data={**db_data} # type: ignore
) )
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": db_data} return {"token": token, "data": db_data}
@ -279,18 +315,17 @@ class PrismaClient:
If data['spend'] + data['user'], update the user table with spend info as well If data['spend'] + data['user'], update the user table with spend info as well
""" """
update_user_row = await self.db.litellm_usertable.update( update_user_row = await self.db.litellm_usertable.update(
where={ where={"user_id": user_id}, # type: ignore
'user_id': user_id # type: ignore data={**db_data}, # type: ignore
},
data={**db_data} # type: ignore
) )
return {"user_id": user_id, "data": db_data} return {"user_id": user_id, "data": db_data}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
@ -310,7 +345,9 @@ class PrismaClient:
) )
return {"deleted_keys": tokens} return {"deleted_keys": tokens}
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@ -325,7 +362,9 @@ class PrismaClient:
try: try:
await self.db.connect() await self.db.connect()
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
# Define a retrying strategy with exponential backoff # Define a retrying strategy with exponential backoff
@ -340,9 +379,12 @@ class PrismaClient:
try: try:
await self.db.disconnect() await self.db.disconnect()
except Exception as e: except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e)) asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
)
raise e raise e
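Every PrismaClient method above is wrapped in the same backoff decorator, which is easier to read in isolation. In the sketch below, max_time and on_backoff mirror the decorator shown in this file, while the decorated coroutine and the max_tries value are illustrative:

import backoff

def on_backoff(details):
    # Same idea as the proxy's handler: report which attempt just failed.
    print(f"Backing off... this was attempt #{details['tries']}")

@backoff.on_exception(
    backoff.expo,           # exponential wait between retries
    Exception,              # retry on any exception
    max_tries=3,            # illustrative cap on attempts
    max_time=10,            # maximum total time to retry for
    on_backoff=on_backoff,  # called before each retry
)
async def flaky_db_call():
    # Stand-in for a Prisma query such as db.litellm_verificationtoken.find_unique(...)
    raise ConnectionError("transient database error")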
### CUSTOM FILE ### ### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
try: try:
@ -357,12 +399,14 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
# If config_file_path is provided, use it to determine the module spec and load the module # If config_file_path is provided, use it to determine the module spec and load the module
if config_file_path is not None: if config_file_path is not None:
directory = os.path.dirname(config_file_path) directory = os.path.dirname(config_file_path)
module_file_path = os.path.join(directory, *module_name.split('.')) module_file_path = os.path.join(directory, *module_name.split("."))
module_file_path += '.py' module_file_path += ".py"
spec = importlib.util.spec_from_file_location(module_name, module_file_path) spec = importlib.util.spec_from_file_location(module_name, module_file_path)
if spec is None: if spec is None:
raise ImportError(f"Could not find a module specification for {module_file_path}") raise ImportError(
f"Could not find a module specification for {module_file_path}"
)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore spec.loader.exec_module(module) # type: ignore
else: else:
@ -379,6 +423,7 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
except Exception as e: except Exception as e:
raise e raise e
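get_instance_fn above turns a dotted string from the proxy config into a live Python object via importlib. Stripped of the config handling, the core mechanism looks roughly like this (function name and error message are placeholders):

import importlib.util
import os

def load_attribute_from_file(module_file_path: str, attr_name: str):
    # Load a module from an explicit .py path and return one attribute from it,
    # e.g. a user-defined callback instance referenced in a config file.
    module_name = os.path.splitext(os.path.basename(module_file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not find a module specification for {module_file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, attr_name)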
### HELPER FUNCTIONS ### ### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
""" """
@ -390,5 +435,7 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
if response is None: # Cache miss if response is None: # Cache miss
user_row = await db.get_data(user_id=user_id) user_row = await db.get_data(user_id=user_id)
cache_value = user_row.model_dump_json() cache_value = user_row.model_dump_json()
cache.set_cache(key=cache_key, value=cache_value, ttl=600) # store for 10 minutes cache.set_cache(
key=cache_key, value=cache_value, ttl=600
) # store for 10 minutes
return return

File diff suppressed because it is too large

View file

@ -8,18 +8,18 @@
import dotenv, os, requests import dotenv, os, requests
from typing import Optional from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
class LeastBusyLoggingHandler(CustomLogger):
def __init__(self, router_cache: DualCache): def __init__(self, router_cache: DualCache):
self.router_cache = router_cache self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {} self.mapping_deployment_to_id: dict = {}
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
""" """
Log when a model is being used. Log when a model is being used.
@ -27,13 +27,16 @@ class LeastBusyLoggingHandler(CustomLogger):
Caching based on model group. Caching based on model group.
""" """
try: try:
if kwargs["litellm_params"].get("metadata") is None:
if kwargs['litellm_params'].get('metadata') is None:
pass pass
else: else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None) deployment = kwargs["litellm_params"]["metadata"].get(
model_group = kwargs['litellm_params']['metadata'].get('model_group', None) "deployment", None
id = kwargs['litellm_params'].get('model_info', {}).get('id', None) )
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if deployment is None or model_group is None or id is None: if deployment is None or model_group is None or id is None:
return return
@ -42,53 +45,75 @@ class LeastBusyLoggingHandler(CustomLogger):
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
# update cache # update cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} request_count_dict = (
request_count_dict[deployment] = request_count_dict.get(deployment, 0) + 1 self.router_cache.get_cache(key=request_count_api_key) or {}
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) )
request_count_dict[deployment] = (
request_count_dict.get(deployment, 0) + 1
)
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e: except Exception as e:
pass pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try: try:
if kwargs['litellm_params'].get('metadata') is None: if kwargs["litellm_params"].get("metadata") is None:
pass pass
else: else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None) deployment = kwargs["litellm_params"]["metadata"].get(
model_group = kwargs['litellm_params']['metadata'].get('model_group', None) "deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None: if deployment is None or model_group is None:
return return
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
# decrement count in cache # decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment) request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e: except Exception as e:
pass pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try: try:
if kwargs['litellm_params'].get('metadata') is None: if kwargs["litellm_params"].get("metadata") is None:
pass pass
else: else:
deployment = kwargs['litellm_params']['metadata'].get('deployment', None) deployment = kwargs["litellm_params"]["metadata"].get(
model_group = kwargs['litellm_params']['metadata'].get('model_group', None) "deployment", None
)
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
if deployment is None or model_group is None: if deployment is None or model_group is None:
return return
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
# decrement count in cache # decrement count in cache
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[deployment] = request_count_dict.get(deployment) request_count_dict[deployment] = request_count_dict.get(deployment)
self.router_cache.set_cache(key=request_count_api_key, value=request_count_dict) self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception as e: except Exception as e:
pass pass
def get_available_deployments(self, model_group: str): def get_available_deployments(self, model_group: str):
request_count_api_key = f"{model_group}_request_count" request_count_api_key = f"{model_group}_request_count"
request_count_dict = self.router_cache.get_cache(key=request_count_api_key) or {} request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
# map deployment to id # map deployment to id
return_dict = {} return_dict = {}
for key, value in request_count_dict.items(): for key, value in request_count_dict.items():

View file

@ -2,6 +2,7 @@
import pytest, sys, os import pytest, sys, os
import importlib import importlib
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
@ -14,17 +15,23 @@ def setup_and_teardown():
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
""" """
curr_dir = os.getcwd() # Get the current working directory curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm import litellm
importlib.reload(litellm) importlib.reload(litellm)
print(litellm) print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding # from litellm import Router, completion, aembedding, acompletion, embedding
yield yield
def pytest_collection_modifyitems(config, items): def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name] custom_logger_tests = [
other_tests = [item for item in items if 'custom_logger' not in item.parent.name] item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names # Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name) custom_logger_tests.sort(key=lambda x: x.name)

View file

@ -4,6 +4,7 @@
import sys, os, time import sys, os, time
import traceback, asyncio import traceback, asyncio
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
@ -11,15 +12,17 @@ import litellm
from litellm import Router from litellm import Router
import concurrent import concurrent
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
model_list = [{ # list of model deployments model_list = [
{ # list of model deployments
"model_name": "gpt-3.5-turbo", # openai model name "model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call "litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": "bad-key", "api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800, "rpm": 1800,
@ -31,24 +34,30 @@ model_list = [{ # list of model deployments
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
"tpm": 1000000, "tpm": 1000000,
"rpm": 9000 "rpm": 9000,
} },
] ]
kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}],} kwargs = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
def test_multiple_deployments_sync(): def test_multiple_deployments_sync():
import concurrent, time import concurrent, time
litellm.set_verbose=False
litellm.set_verbose = False
results = [] results = []
router = Router(model_list=model_list, router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"), redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"), redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle", routing_strategy="simple-shuffle",
set_verbose=True, set_verbose=True,
num_retries=1) # type: ignore num_retries=1,
) # type: ignore
try: try:
for _ in range(3): for _ in range(3):
response = router.completion(**kwargs) response = router.completion(**kwargs)
@ -59,6 +68,7 @@ def test_multiple_deployments_sync():
print(f"FAILED TEST!") print(f"FAILED TEST!")
pytest.fail(f"An error occurred - {traceback.format_exc()}") pytest.fail(f"An error occurred - {traceback.format_exc()}")
# test_multiple_deployments_sync() # test_multiple_deployments_sync()
@ -67,13 +77,15 @@ def test_multiple_deployments_parallel():
results = [] results = []
futures = {} futures = {}
start_time = time.time() start_time = time.time()
router = Router(model_list=model_list, router = Router(
model_list=model_list,
redis_host=os.getenv("REDIS_HOST"), redis_host=os.getenv("REDIS_HOST"),
redis_password=os.getenv("REDIS_PASSWORD"), redis_password=os.getenv("REDIS_PASSWORD"),
redis_port=int(os.getenv("REDIS_PORT")), # type: ignore redis_port=int(os.getenv("REDIS_PORT")), # type: ignore
routing_strategy="simple-shuffle", routing_strategy="simple-shuffle",
set_verbose=True, set_verbose=True,
num_retries=1) # type: ignore num_retries=1,
) # type: ignore
# Assuming you have an executor instance defined somewhere in your code # Assuming you have an executor instance defined somewhere in your code
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
for _ in range(5): for _ in range(5):
@ -82,7 +94,11 @@ def test_multiple_deployments_parallel():
# Retrieve the results from the futures # Retrieve the results from the futures
while futures: while futures:
done, not_done = concurrent.futures.wait(futures.values(), timeout=10, return_when=concurrent.futures.FIRST_COMPLETED) done, not_done = concurrent.futures.wait(
futures.values(),
timeout=10,
return_when=concurrent.futures.FIRST_COMPLETED,
)
for future in done: for future in done:
try: try:
result = future.result() result = future.result()
@ -98,8 +114,10 @@ def test_multiple_deployments_parallel():
print(results) print(results)
print(f"ELAPSED TIME: {end_time - start_time}") print(f"ELAPSED TIME: {end_time - start_time}")
# Assuming litellm, router, and executor are defined somewhere in your code # Assuming litellm, router, and executor are defined somewhere in your code
# test_multiple_deployments_parallel() # test_multiple_deployments_parallel()
def test_cooldown_same_model_name(): def test_cooldown_same_model_name():
# users could have the same model with different api_base # users could have the same model with different api_base
@ -118,7 +136,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": "BAD_API_BASE", "api_base": "BAD_API_BASE",
"tpm": 90 "tpm": 90,
}, },
}, },
{ {
@ -128,7 +146,7 @@ def test_cooldown_same_model_name():
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"tpm": 0.000001 "tpm": 0.000001,
}, },
}, },
] ]
@ -140,17 +158,12 @@ def test_cooldown_same_model_name():
redis_port=int(os.getenv("REDIS_PORT")), redis_port=int(os.getenv("REDIS_PORT")),
routing_strategy="simple-shuffle", routing_strategy="simple-shuffle",
set_verbose=True, set_verbose=True,
num_retries=3 num_retries=3,
) # type: ignore ) # type: ignore
response = router.completion( response = router.completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[ messages=[{"role": "user", "content": "hello this request will pass"}],
{
"role": "user",
"content": "hello this request will pass"
}
]
) )
print(router.model_list) print(router.model_list)
model_ids = [] model_ids = []
@ -159,10 +172,13 @@ def test_cooldown_same_model_name():
print("\n litellm model ids ", model_ids) print("\n litellm model ids ", model_ids)
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960'] # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names assert (
model_ids[0] != model_ids[1]
) # ensure both models have a uuid added, and they have different names
print("\ngot response\n", response) print("\ngot response\n", response)
except Exception as e: except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {e}") pytest.fail(f"Got unexpected exception on router! - {e}")
test_cooldown_same_model_name() test_cooldown_same_model_name()

View file

@ -9,11 +9,12 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm import litellm
## case 1: set_function_to_prompt not set ## case 1: set_function_to_prompt not set
def test_function_call_non_openai_model(): def test_function_call_non_openai_model():
try: try:
model = "claude-instant-1" model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}] messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [ functions = [
{ {
"name": "get_current_weather", "name": "get_current_weather",
@ -23,31 +24,32 @@ def test_function_call_non_openai_model():
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA" "description": "The city and state, e.g. San Francisco, CA",
}, },
"unit": { "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"type": "string", },
"enum": ["celsius", "fahrenheit"] "required": ["location"],
}
}, },
"required": ["location"]
}
} }
] ]
response = litellm.completion(model=model, messages=messages, functions=functions) response = litellm.completion(
pytest.fail(f'An error occurred') model=model, messages=messages, functions=functions
)
pytest.fail(f"An error occurred")
except Exception as e: except Exception as e:
print(e) print(e)
pass pass
test_function_call_non_openai_model() test_function_call_non_openai_model()
## case 2: add_function_to_prompt set ## case 2: add_function_to_prompt set
def test_function_call_non_openai_model_litellm_mod_set(): def test_function_call_non_openai_model_litellm_mod_set():
litellm.add_function_to_prompt = True litellm.add_function_to_prompt = True
try: try:
model = "claude-instant-1" model = "claude-instant-1"
messages=[{"role": "user", "content": "what's the weather in sf?"}] messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [ functions = [
{ {
"name": "get_current_weather", "name": "get_current_weather",
@ -57,20 +59,20 @@ def test_function_call_non_openai_model_litellm_mod_set():
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA" "description": "The city and state, e.g. San Francisco, CA",
}, },
"unit": { "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
"type": "string", },
"enum": ["celsius", "fahrenheit"] "required": ["location"],
}
}, },
"required": ["location"]
}
} }
] ]
response = litellm.completion(model=model, messages=messages, functions=functions) response = litellm.completion(
print(f'response: {response}') model=model, messages=messages, functions=functions
)
print(f"response: {response}")
except Exception as e: except Exception as e:
pytest.fail(f'An error occurred {e}') pytest.fail(f"An error occurred {e}")
# test_function_call_non_openai_model_litellm_mod_set() # test_function_call_non_openai_model_litellm_mod_set()

View file

@ -1,4 +1,3 @@
import sys, os import sys, os
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
@ -27,11 +26,11 @@ def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file # Define the path to the vertex_key.json file
print("loading vertex ai credentials") print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__)) filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + '/vertex_key.json' vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary # Read the existing content of the file or create an empty dictionary
try: try:
with open(vertex_key_path, 'r') as file: with open(vertex_key_path, "r") as file:
# Read the file content # Read the file content
print("Read vertexai file path") print("Read vertexai file path")
content = file.read() content = file.read()
@ -55,13 +54,13 @@ def load_vertex_ai_credentials():
service_account_key_data["private_key"] = private_key service_account_key_data["private_key"] = private_key
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file: with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file # Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2) json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
@pytest.mark.asyncio @pytest.mark.asyncio
async def get_response(): async def get_response():
@ -89,43 +88,80 @@ def test_vertex_ai():
import random import random
load_vertex_ai_credentials() load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = (
litellm.set_verbose=False litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
litellm.set_verbose = False
litellm.vertex_project = "reliablekeys" litellm.vertex_project = "reliablekeys"
test_models = random.sample(test_models, 1) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
try: try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model # our account does not have access to this model
continue continue
print("making request", model) print("making request", model)
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7) response = completion(
model=model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
)
print("\nModel Response", response) print("\nModel Response", response)
print(response) print(response)
assert type(response.choices[0].message.content) == str assert type(response.choices[0].message.content) == str
assert len(response.choices[0].message.content) > 1 assert len(response.choices[0].message.content) > 1
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_vertex_ai() # test_vertex_ai()
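For orientation, the edits in these hunks follow black's default style: double quotes, long calls exploded to one argument per line with a trailing comma, and long expressions wrapped in parentheses. An illustrative before/after sketch (not an additional change in this commit), assuming the litellm completion API used above:

from litellm import completion

# before black: one long, single-quoted line
# response = completion(model="gemini-pro", messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)

# after black: double quotes, one argument per line, magic trailing comma
response = completion(
    model="gemini-pro",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.7,
)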
def test_vertex_ai_stream(): def test_vertex_ai_stream():
load_vertex_ai_credentials() load_vertex_ai_credentials()
litellm.set_verbose=False litellm.set_verbose = False
litellm.vertex_project = "reliablekeys" litellm.vertex_project = "reliablekeys"
import random import random
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
try: try:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model # our account does not have access to this model
continue continue
print("making request", model) print("making request", model)
response = completion(model=model, messages=[{"role": "user", "content": "write 10 line code code for saying hi"}], stream=True) response = completion(
model=model,
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
],
stream=True,
)
completed_str = "" completed_str = ""
for chunk in response: for chunk in response:
print(chunk) print(chunk)
@ -137,47 +173,86 @@ def test_vertex_ai_stream():
assert len(completed_str) > 4 assert len(completed_str) > 4
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_vertex_ai_stream() # test_vertex_ai_stream()
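The streaming test above accumulates chunk deltas into a single string; a minimal sketch of that pattern (assuming Vertex AI credentials are configured as above and that chunks follow the OpenAI delta shape):

import litellm

response = litellm.completion(
    model="gemini-pro",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
completed_str = ""
for chunk in response:
    # delta content can be None (e.g. on the final chunk), hence the `or ""`
    completed_str += chunk.choices[0].delta.content or ""
print(completed_str)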
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_vertexai_response(): async def test_async_vertexai_response():
import random import random
load_vertex_ai_credentials() load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
print(f'model being tested in async call: {model}') print(f"model being tested in async call: {model}")
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model # our account does not have access to this model
continue continue
try: try:
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5) response = await acompletion(
model=model, messages=messages, temperature=0.7, timeout=5
)
print(f"response: {response}") print(f"response: {response}")
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_response()) # asyncio.run(test_async_vertexai_response())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_vertexai_streaming_response(): async def test_async_vertexai_streaming_response():
import random import random
load_vertex_ai_credentials() load_vertex_ai_credentials()
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models test_models = (
litellm.vertex_chat_models
+ litellm.vertex_code_chat_models
+ litellm.vertex_text_models
+ litellm.vertex_code_text_models
)
test_models = random.sample(test_models, 1) test_models = random.sample(test_models, 1)
test_models += litellm.vertex_language_models # always test gemini-pro test_models += litellm.vertex_language_models # always test gemini-pro
for model in test_models: for model in test_models:
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]: if model in [
"code-gecko",
"code-gecko@001",
"code-gecko@002",
"code-gecko@latest",
"code-bison@001",
"text-bison@001",
]:
# our account does not have access to this model # our account does not have access to this model
continue continue
try: try:
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True) response = await acompletion(
model="gemini-pro",
messages=messages,
temperature=0.7,
timeout=5,
stream=True,
)
print(f"response: {response}") print(f"response: {response}")
complete_response = "" complete_response = ""
async for chunk in response: async for chunk in response:
@ -191,38 +266,40 @@ async def test_async_vertexai_streaming_response():
print(e) print(e)
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# asyncio.run(test_async_vertexai_streaming_response()) # asyncio.run(test_async_vertexai_streaming_response())
def test_gemini_pro_vision(): def test_gemini_pro_vision():
try: try:
load_vertex_ai_credentials() load_vertex_ai_credentials()
litellm.set_verbose = True litellm.set_verbose = True
litellm.num_retries=0 litellm.num_retries = 0
resp = litellm.completion( resp = litellm.completion(
model = "vertex_ai/gemini-pro-vision", model="vertex_ai/gemini-pro-vision",
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "Whats in this image?"},
"type": "text",
"text": "Whats in this image?"
},
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
} },
} },
] ],
} }
], ],
) )
print(resp) print(resp)
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise e raise e
# test_gemini_pro_vision() # test_gemini_pro_vision()

View file

@ -11,8 +11,10 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm import litellm
from litellm import completion, acompletion, acreate from litellm import completion, acompletion, acreate
litellm.num_retries = 3 litellm.num_retries = 3
def test_sync_response(): def test_sync_response():
litellm.set_verbose = False litellm.set_verbose = False
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
@ -24,28 +26,42 @@ def test_sync_response():
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_sync_response() # test_sync_response()
def test_sync_response_anyscale(): def test_sync_response_anyscale():
litellm.set_verbose = False litellm.set_verbose = False
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5) response = completion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_sync_response_anyscale() # test_sync_response_anyscale()
def test_async_response_openai(): def test_async_response_openai():
import asyncio import asyncio
litellm.set_verbose = True litellm.set_verbose = True
async def test_get_response(): async def test_get_response():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5) response = await acompletion(
model="gpt-3.5-turbo", messages=messages, timeout=5
)
print(f"response: {response}") print(f"response: {response}")
print(f"response ms: {response._response_ms}") print(f"response ms: {response._response_ms}")
except litellm.Timeout as e: except litellm.Timeout as e:
@ -56,16 +72,25 @@ def test_async_response_openai():
asyncio.run(test_get_response()) asyncio.run(test_get_response())
# test_async_response_openai() # test_async_response_openai()
def test_async_response_azure(): def test_async_response_azure():
import asyncio import asyncio
litellm.set_verbose = True litellm.set_verbose = True
async def test_get_response(): async def test_get_response():
user_message = "What do you know?" user_message = "What do you know?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY")) response = await acompletion(
model="azure/gpt-turbo",
messages=messages,
base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
)
print(f"response: {response}") print(f"response: {response}")
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
@ -74,17 +99,24 @@ def test_async_response_azure():
asyncio.run(test_get_response()) asyncio.run(test_get_response())
# test_async_response_azure() # test_async_response_azure()
def test_async_anyscale_response(): def test_async_anyscale_response():
import asyncio import asyncio
litellm.set_verbose = True litellm.set_verbose = True
async def test_get_response(): async def test_get_response():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5) response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
timeout=5,
)
# response = await response # response = await response
print(f"response: {response}") print(f"response: {response}")
except litellm.Timeout as e: except litellm.Timeout as e:
@ -94,16 +126,21 @@ def test_async_anyscale_response():
asyncio.run(test_get_response()) asyncio.run(test_get_response())
# test_async_anyscale_response() # test_async_anyscale_response()
def test_get_response_streaming(): def test_get_response_streaming():
import asyncio import asyncio
async def test_async_call(): async def test_async_call():
user_message = "write a short poem in one sentence" user_message = "write a short poem in one sentence"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
litellm.set_verbose = True litellm.set_verbose = True
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5) response = await acompletion(
model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5
)
print(type(response)) print(type(response))
import inspect import inspect
@ -121,24 +158,34 @@ def test_get_response_streaming():
assert output is not None, "output cannot be None." assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str" assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0." assert len(output) > 0, "Length of output needs to be greater than 0."
print(f'output: {output}') print(f"output: {output}")
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_async_call()) asyncio.run(test_async_call())
# test_get_response_streaming() # test_get_response_streaming()
def test_get_response_non_openai_streaming(): def test_get_response_non_openai_streaming():
import asyncio import asyncio
litellm.set_verbose = True litellm.set_verbose = True
litellm.num_retries = 0 litellm.num_retries = 0
async def test_async_call(): async def test_async_call():
user_message = "Hello, how are you?" user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
try: try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5) response = await acompletion(
model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
stream=True,
timeout=5,
)
print(type(response)) print(type(response))
import inspect import inspect
@ -163,6 +210,8 @@ def test_get_response_non_openai_streaming():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
return response return response
asyncio.run(test_async_call()) asyncio.run(test_async_call())
# test_get_response_non_openai_streaming() # test_get_response_non_openai_streaming()
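The async tests above share one pattern: await acompletion with a short timeout and treat litellm.Timeout as acceptable. A minimal standalone sketch (assuming OPENAI_API_KEY is set):

import asyncio
import litellm

async def main():
    try:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            timeout=5,
        )
        print(response.choices[0].message.content)
    except litellm.Timeout:
        # a short client-side timeout is tolerated, mirroring the tests above
        print("request timed out")

asyncio.run(main())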

View file

@ -3,14 +3,15 @@
import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
import openai, litellm, uuid import openai, litellm, uuid
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
client = AsyncAzureOpenAI( client = AsyncAzureOpenAI(
api_key=os.getenv("AZURE_API_KEY"), api_key=os.getenv("AZURE_API_KEY"),
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
api_version=os.getenv("AZURE_API_VERSION") api_version=os.getenv("AZURE_API_VERSION"),
) )
model_list = [ model_list = [
@ -20,59 +21,84 @@ model_list = [
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"api_version": os.getenv("AZURE_API_VERSION") "api_version": os.getenv("AZURE_API_VERSION"),
} },
} }
] ]
router = litellm.Router(model_list=model_list) router = litellm.Router(model_list=model_list)
async def _openai_completion(): async def _openai_completion():
try: try:
start_time = time.time() start_time = time.time()
response = await client.chat.completions.create( response = await client.chat.completions.create(
model="chatgpt-v-2", model="chatgpt-v-2",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True stream=True,
) )
time_to_first_token = None time_to_first_token = None
first_token_ts = None first_token_ts = None
init_chunk = None init_chunk = None
async for chunk in response: async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time() first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time time_to_first_token = first_token_ts - start_time
init_chunk = chunk init_chunk = chunk
end_time = time.time() end_time = time.time()
print("OpenAI Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time) print(
"OpenAI Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time,
)
return time_to_first_token return time_to_first_token
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
async def _router_completion(): async def _router_completion():
try: try:
start_time = time.time() start_time = time.time()
response = await router.acompletion( response = await router.acompletion(
model="azure-test", model="azure-test",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
stream=True stream=True,
) )
time_to_first_token = None time_to_first_token = None
first_token_ts = None first_token_ts = None
init_chunk = None init_chunk = None
async for chunk in response: async for chunk in response:
if time_to_first_token is None and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: if (
time_to_first_token is None
and len(chunk.choices) > 0
and chunk.choices[0].delta.content is not None
):
first_token_ts = time.time() first_token_ts = time.time()
time_to_first_token = first_token_ts - start_time time_to_first_token = first_token_ts - start_time
init_chunk = chunk init_chunk = chunk
end_time = time.time() end_time = time.time()
print("Router Call: ",init_chunk, start_time, first_token_ts, time_to_first_token, end_time - first_token_ts) print(
"Router Call: ",
init_chunk,
start_time,
first_token_ts,
time_to_first_token,
end_time - first_token_ts,
)
return time_to_first_token return time_to_first_token
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
async def test_azure_completion_streaming(): async def test_azure_completion_streaming():
""" """
Test azure streaming call - measure on time to first (non-null) token. Test azure streaming call - measure on time to first (non-null) token.
@ -85,7 +111,7 @@ async def test_azure_completion_streaming():
total_time = 0 total_time = 0
for item in successful_completions: for item in successful_completions:
total_time += item total_time += item
avg_openai_time = total_time/3 avg_openai_time = total_time / 3
## ROUTER AVG. TIME ## ROUTER AVG. TIME
tasks = [_router_completion() for _ in range(n)] tasks = [_router_completion() for _ in range(n)]
chat_completions = await asyncio.gather(*tasks) chat_completions = await asyncio.gather(*tasks)
@ -93,9 +119,10 @@ async def test_azure_completion_streaming():
total_time = 0 total_time = 0
for item in successful_completions: for item in successful_completions:
total_time += item total_time += item
avg_router_time = total_time/3 avg_router_time = total_time / 3
## COMPARE ## COMPARE
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}") print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
assert avg_router_time < avg_openai_time + 0.5 assert avg_router_time < avg_openai_time + 0.5
# asyncio.run(test_azure_completion_streaming()) # asyncio.run(test_azure_completion_streaming())
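Both helpers above measure time to the first non-null token of a streamed response. A minimal sketch of that measurement factored into a helper (hypothetical name, not part of the commit):

import time

async def time_to_first_token(stream):
    # `stream` is an async iterator of chat-completion chunks, as returned above
    start_time = time.time()
    async for chunk in stream:
        if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
            return time.time() - start_time
    return None  # stream ended without any non-null content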

View file

@ -5,6 +5,7 @@
import sys, os import sys, os
import traceback import traceback
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
@ -18,6 +19,7 @@ user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}] messages = [{"content": user_message, "role": "user"}]
model_val = None model_val = None
def test_completion_with_no_model(): def test_completion_with_no_model():
# test on empty # test on empty
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -32,6 +34,7 @@ def test_completion_with_empty_model():
print(f"error occurred: {e}") print(f"error occurred: {e}")
pass pass
# def test_completion_catch_nlp_exception(): # def test_completion_catch_nlp_exception():
# TEMP commented out NLP cloud API is unstable # TEMP commented out NLP cloud API is unstable
# try: # try:
@ -64,6 +67,7 @@ def test_completion_with_empty_model():
# test_completion_catch_nlp_exception() # test_completion_catch_nlp_exception()
def test_completion_invalid_param_cohere(): def test_completion_invalid_param_cohere():
try: try:
response = completion(model="command-nightly", messages=messages, top_p=1) response = completion(model="command-nightly", messages=messages, top_p=1)
@ -72,14 +76,18 @@ def test_completion_invalid_param_cohere():
if "Unsupported parameters passed: top_p" in str(e): if "Unsupported parameters passed: top_p" in str(e):
pass pass
else: else:
pytest.fail(f'An error occurred {e}') pytest.fail(f"An error occurred {e}")
# test_completion_invalid_param_cohere() # test_completion_invalid_param_cohere()
def test_completion_function_call_cohere(): def test_completion_function_call_cohere():
try: try:
response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]) response = completion(
pytest.fail(f'An error occurred {e}') model="command-nightly", messages=messages, functions=["TEST-FUNCTION"]
)
pytest.fail(f"An error occurred {e}")
except Exception as e: except Exception as e:
print(e) print(e)
pass pass
@ -87,10 +95,14 @@ def test_completion_function_call_cohere():
# test_completion_function_call_cohere() # test_completion_function_call_cohere()
def test_completion_function_call_openai(): def test_completion_function_call_openai():
try: try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}] messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(model="gpt-3.5-turbo", messages=messages, functions=[ response = completion(
model="gpt-3.5-turbo",
messages=messages,
functions=[
{ {
"name": "get_current_weather", "name": "get_current_weather",
"description": "Get the current weather in a given location", "description": "Get the current weather in a given location",
@ -99,23 +111,26 @@ def test_completion_function_call_openai():
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA" "description": "The city and state, e.g. San Francisco, CA",
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"enum": ["celsius", "fahrenheit"] "enum": ["celsius", "fahrenheit"],
} },
},
"required": ["location"],
}, },
"required": ["location"]
} }
} ],
]) )
print(f"response: {response}") print(f"response: {response}")
except: except:
pass pass
# test_completion_function_call_openai() # test_completion_function_call_openai()
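The function-calling test above passes an OpenAI-style function schema; a condensed sketch of the same call (assuming OPENAI_API_KEY is set and the response mirrors the OpenAI shape):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=[
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        }
    ],
)
# the model's tool request, if any, is on the returned message
print(response.choices[0].message)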
def test_completion_with_no_provider(): def test_completion_with_no_provider():
# test on empty # test on empty
try: try:
@ -125,6 +140,7 @@ def test_completion_with_no_provider():
print(f"error occurred: {e}") print(f"error occurred: {e}")
pass pass
# test_completion_with_no_provider() # test_completion_with_no_provider()
# # bad key # # bad key
# temp_key = os.environ.get("OPENAI_API_KEY") # temp_key = os.environ.get("OPENAI_API_KEY")

View file

@ -4,15 +4,24 @@
import sys, os import sys, os
import traceback import traceback
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from openai import APITimeoutError as Timeout from openai import APITimeoutError as Timeout
import litellm import litellm
litellm.num_retries = 0 litellm.num_retries = 0
from litellm import batch_completion, batch_completion_models, completion, batch_completion_models_all_responses from litellm import (
batch_completion,
batch_completion_models,
completion,
batch_completion_models_all_responses,
)
# litellm.set_verbose=True # litellm.set_verbose=True
def test_batch_completions(): def test_batch_completions():
messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)] messages = [[{"role": "user", "content": "write a short poem"}] for _ in range(3)]
model = "j2-mid" model = "j2-mid"
@ -23,43 +32,50 @@ def test_batch_completions():
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
temperature=0.2, temperature=0.2,
request_timeout=1 request_timeout=1,
) )
print(result) print(result)
print(len(result)) print(len(result))
assert(len(result)==3) assert len(result) == 3
except Timeout as e: except Timeout as e:
print(f"IN TIMEOUT") print(f"IN TIMEOUT")
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
test_batch_completions() test_batch_completions()
def test_batch_completions_models(): def test_batch_completions_models():
try: try:
result = batch_completion_models( result = batch_completion_models(
models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"], models=["gpt-3.5-turbo", "gpt-3.5-turbo", "gpt-3.5-turbo"],
messages=[{"role": "user", "content": "Hey, how's it going"}] messages=[{"role": "user", "content": "Hey, how's it going"}],
) )
print(result) print(result)
except Timeout as e: except Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
# test_batch_completions_models() # test_batch_completions_models()
def test_batch_completion_models_all_responses(): def test_batch_completion_models_all_responses():
try: try:
responses = batch_completion_models_all_responses( responses = batch_completion_models_all_responses(
models=["j2-light", "claude-instant-1.2"], models=["j2-light", "claude-instant-1.2"],
messages=[{"role": "user", "content": "write a poem"}], messages=[{"role": "user", "content": "write a poem"}],
max_tokens=10 max_tokens=10,
) )
print(responses) print(responses)
assert(len(responses) == 2) assert len(responses) == 2
except Timeout as e: except Timeout as e:
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"An error occurred: {e}") pytest.fail(f"An error occurred: {e}")
# test_batch_completion_models_all_responses()
# test_batch_completion_models_all_responses()
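For reference, a minimal sketch of the batch helper exercised above (the model choice here is an assumption; the tests use j2-mid, j2-light, and claude-instant-1.2):

from litellm import batch_completion

# three independent requests, same model and params
responses = batch_completion(
    model="gpt-3.5-turbo",
    messages=[[{"role": "user", "content": "write a short poem"}] for _ in range(3)],
    max_tokens=10,
    temperature=0.2,
)
assert len(responses) == 3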

View file

@ -14,6 +14,7 @@ import litellm
from litellm import embedding, completion from litellm import embedding, completion
from litellm.caching import Cache from litellm.caching import Cache
import random import random
# litellm.set_verbose=True # litellm.set_verbose=True
messages = [{"role": "user", "content": "who is ishaan Github? "}] messages = [{"role": "user", "content": "who is ishaan Github? "}]
@ -22,14 +23,18 @@ messages = [{"role": "user", "content": "who is ishaan Github? "}]
import random import random
import string import string
def generate_random_word(length=4): def generate_random_word(length=4):
letters = string.ascii_lowercase letters = string.ascii_lowercase
return ''.join(random.choice(letters) for _ in range(length)) return "".join(random.choice(letters) for _ in range(length))
messages = [{"role": "user", "content": "who is ishaan 5222"}] messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache def test_caching_v2(): # test in memory cache
try: try:
litellm.set_verbose=True litellm.set_verbose = True
litellm.cache = Cache() litellm.cache = Cache()
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@ -38,7 +43,10 @@ def test_caching_v2(): # test in memory cache
litellm.cache = None # disable cache litellm.cache = None # disable cache
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
@ -46,12 +54,14 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_v2() # test_caching_v2()
def test_caching_with_models_v2(): def test_caching_with_models_v2():
messages = [{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}] messages = [
{"role": "user", "content": "who is ishaan CTO of litellm from litellm 2023"}
]
litellm.cache = Cache() litellm.cache = Cache()
print("test2 for caching") print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
@ -63,34 +73,51 @@ def test_caching_with_models_v2():
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response # if models are different, it should not return cached response
print(f"response2: {response2}") print(f"response2: {response2}")
print(f"response3: {response3}") print(f"response3: {response3}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
# test_caching_with_models_v2() # test_caching_with_models_v2()
embedding_large_text = """ embedding_large_text = (
"""
small text small text
""" * 5 """
* 5
)
# # test_caching_with_models() # # test_caching_with_models()
def test_embedding_caching(): def test_embedding_caching():
import time import time
litellm.cache = Cache() litellm.cache = Cache()
text_to_embed = [embedding_large_text] text_to_embed = [embedding_large_text]
start_time = time.time() start_time = time.time()
embedding1 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) embedding1 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time() end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds") print(f"Embedding 1 response time: {end_time - start_time} seconds")
time.sleep(1) time.sleep(1)
start_time = time.time() start_time = time.time()
embedding2 = embedding(model="text-embedding-ada-002", input=text_to_embed, caching=True) embedding2 = embedding(
model="text-embedding-ada-002", input=text_to_embed, caching=True
)
end_time = time.time() end_time = time.time()
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
@ -99,28 +126,29 @@ def test_embedding_caching():
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}") print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed") pytest.fail("Error occurred: Embedding caching failed")
# test_embedding_caching() # test_embedding_caching()
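A minimal sketch of the in-memory caching pattern the tests above rely on: set a global Cache, opt calls in with caching=True, and clear it afterwards.

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory cache by default

messages = [{"role": "user", "content": "who is ishaan 5222"}]
response1 = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = litellm.completion(model="gpt-3.5-turbo", messages=messages, caching=True)

# identical params should be a cache hit, so the contents match
assert (
    response1["choices"][0]["message"]["content"]
    == response2["choices"][0]["message"]["content"]
)

litellm.cache = None  # disable caching again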
def test_embedding_caching_azure(): def test_embedding_caching_azure():
print("Testing azure embedding caching") print("Testing azure embedding caching")
import time import time
litellm.cache = Cache() litellm.cache = Cache()
text_to_embed = [embedding_large_text] text_to_embed = [embedding_large_text]
api_key = os.environ['AZURE_API_KEY'] api_key = os.environ["AZURE_API_KEY"]
api_base = os.environ['AZURE_API_BASE'] api_base = os.environ["AZURE_API_BASE"]
api_version = os.environ['AZURE_API_VERSION'] api_version = os.environ["AZURE_API_VERSION"]
os.environ['AZURE_API_VERSION'] = ""
os.environ['AZURE_API_BASE'] = ""
os.environ['AZURE_API_KEY'] = ""
os.environ["AZURE_API_VERSION"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_KEY"] = ""
start_time = time.time() start_time = time.time()
print("AZURE CONFIGS") print("AZURE CONFIGS")
@ -133,7 +161,7 @@ def test_embedding_caching_azure():
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
caching=True caching=True,
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 1 response time: {end_time - start_time} seconds") print(f"Embedding 1 response time: {end_time - start_time} seconds")
@ -146,7 +174,7 @@ def test_embedding_caching_azure():
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
caching=True caching=True,
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
@ -155,14 +183,15 @@ def test_embedding_caching_azure():
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']: if embedding2["data"][0]["embedding"] != embedding1["data"][0]["embedding"]:
print(f"embedding1: {embedding1}") print(f"embedding1: {embedding1}")
print(f"embedding2: {embedding2}") print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed") pytest.fail("Error occurred: Embedding caching failed")
os.environ['AZURE_API_VERSION'] = api_version os.environ["AZURE_API_VERSION"] = api_version
os.environ['AZURE_API_BASE'] = api_base os.environ["AZURE_API_BASE"] = api_base
os.environ['AZURE_API_KEY'] = api_key os.environ["AZURE_API_KEY"] = api_key
# test_embedding_caching_azure() # test_embedding_caching_azure()
@ -170,13 +199,28 @@ def test_embedding_caching_azure():
def test_redis_cache_completion(): def test_redis_cache_completion():
litellm.set_verbose = False litellm.set_verbose = False
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
random_number = random.randint(
    1, 100000
)  # add a random number to ensure it's always adding / reading from cache
messages = [
    {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)
print("test2 for caching") print("test2 for caching")
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20)
response3 = completion(model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5)
response1 = completion(
    model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response2 = completion(
    model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
)
response3 = completion(
    model="gpt-3.5-turbo", messages=messages, caching=True, temperature=0.5
)
response4 = completion(model="command-nightly", messages=messages, caching=True) response4 = completion(model="command-nightly", messages=messages, caching=True)
print("\nresponse 1", response1) print("\nresponse 1", response1)
@ -192,49 +236,88 @@ def test_redis_cache_completion():
1 & 3 should be different, since input params are diff 1 & 3 should be different, since input params are diff
1 & 4 should be diff, since models are diff 1 & 4 should be diff, since models are diff
""" """
if response1['choices'][0]['message']['content'] != response2['choices'][0]['message']['content']: # 1 and 2 should be the same
if (
    response1["choices"][0]["message"]["content"]
    != response2["choices"][0]["message"]["content"]
):  # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT # 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
if response1['choices'][0]['message']['content'] == response3['choices'][0]['message']['content']:
if (
    response1["choices"][0]["message"]["content"]
    == response3["choices"][0]["message"]["content"]
):
# if input params like seed, max_tokens are diff it should NOT be a cache hit # if input params like seed, max_tokens are diff it should NOT be a cache hit
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response3: {response3}") print(f"response3: {response3}")
pytest.fail(f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:")
pytest.fail(
    f"Response 1 == response 3. Same model, diff params shoudl not cache Error occurred:"
)
if response1['choices'][0]['message']['content'] == response4['choices'][0]['message']['content']:
if (
    response1["choices"][0]["message"]["content"]
    == response4["choices"][0]["message"]["content"]
):
# if models are different, it should not return cached response # if models are different, it should not return cached response
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response4: {response4}") print(f"response4: {response4}")
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
# test_redis_cache_completion() # test_redis_cache_completion()
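The Redis-backed cache used above is configured entirely from environment variables; a minimal sketch:

import os
import litellm
from litellm.caching import Cache

# assumes REDIS_HOST / REDIS_PORT / REDIS_PASSWORD are set in the environment
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a one sentence poem"}],
    caching=True,
)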
def test_redis_cache_completion_stream(): def test_redis_cache_completion_stream():
try: try:
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
litellm.callbacks = [] litellm.callbacks = []
litellm.set_verbose = True litellm.set_verbose = True
random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache random_number = random.randint(
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] 1, 100000
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) ) # add a random number to ensure it's always adding / reading from cache
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_1_content = "" response_1_content = ""
for chunk in response1: for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
time.sleep(0.5) time.sleep(0.5)
response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=0.2,
stream=True,
)
response_2_content = "" response_2_content = ""
for chunk in response2: for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.success_callback = [] litellm.success_callback = []
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
@ -247,99 +330,171 @@ def test_redis_cache_completion_stream():
1 & 2 should be exactly the same 1 & 2 should be exactly the same
""" """
# test_redis_cache_completion_stream() # test_redis_cache_completion_stream()
def test_redis_cache_acompletion_stream(): def test_redis_cache_acompletion_stream():
import asyncio import asyncio
try: try:
litellm.set_verbose = True litellm.set_verbose = True
random_word = generate_random_word() random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] messages = [
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) {
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response_1_content = "" response_1_content = ""
response_2_content = "" response_2_content = ""
async def call1(): async def call1():
nonlocal response_1_content nonlocal response_1_content
response1 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) response1 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1: async for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
asyncio.run(call1()) asyncio.run(call1())
time.sleep(0.5) time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n") print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2(): async def call2():
nonlocal response_2_content nonlocal response_2_content
response2 = await litellm.acompletion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=1, stream=True) response2 = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2: async for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content) print(response_2_content)
asyncio.run(call2()) asyncio.run(call2())
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
except Exception as e: except Exception as e:
print(e) print(e)
raise e raise e
# test_redis_cache_acompletion_stream() # test_redis_cache_acompletion_stream()
def test_redis_cache_acompletion_stream_bedrock(): def test_redis_cache_acompletion_stream_bedrock():
import asyncio import asyncio
try: try:
litellm.set_verbose = True litellm.set_verbose = True
random_word = generate_random_word() random_word = generate_random_word()
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}] messages = [
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) {
"role": "user",
"content": f"write a one sentence poem about: {random_word}",
}
]
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
print("test for caching, streaming + completion") print("test for caching, streaming + completion")
response_1_content = "" response_1_content = ""
response_2_content = "" response_2_content = ""
async def call1(): async def call1():
nonlocal response_1_content nonlocal response_1_content
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) response1 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response1: async for chunk in response1:
print(chunk) print(chunk)
response_1_content += chunk.choices[0].delta.content or "" response_1_content += chunk.choices[0].delta.content or ""
print(response_1_content) print(response_1_content)
asyncio.run(call1()) asyncio.run(call1())
time.sleep(0.5) time.sleep(0.5)
print("\n\n Response 1 content: ", response_1_content, "\n\n") print("\n\n Response 1 content: ", response_1_content, "\n\n")
async def call2(): async def call2():
nonlocal response_2_content nonlocal response_2_content
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True) response2 = await litellm.acompletion(
model="bedrock/anthropic.claude-v1",
messages=messages,
max_tokens=40,
temperature=1,
stream=True,
)
async for chunk in response2: async for chunk in response2:
print(chunk) print(chunk)
response_2_content += chunk.choices[0].delta.content or "" response_2_content += chunk.choices[0].delta.content or ""
print(response_2_content) print(response_2_content)
asyncio.run(call2()) asyncio.run(call2())
print("\nresponse 1", response_1_content) print("\nresponse 1", response_1_content)
print("\nresponse 2", response_2_content) print("\nresponse 2", response_2_content)
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" assert (
response_1_content == response_2_content
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
except Exception as e: except Exception as e:
print(e) print(e)
raise e raise e
# test_redis_cache_acompletion_stream_bedrock() # test_redis_cache_acompletion_stream_bedrock()
# redis cache with custom keys # redis cache with custom keys
def custom_get_cache_key(*args, **kwargs): def custom_get_cache_key(*args, **kwargs):
# return key to use for your cache: # return key to use for your cache:
key = kwargs.get("model", "") + str(kwargs.get("messages", "")) + str(kwargs.get("temperature", "")) + str(kwargs.get("logit_bias", "")) key = (
kwargs.get("model", "")
+ str(kwargs.get("messages", ""))
+ str(kwargs.get("temperature", ""))
+ str(kwargs.get("logit_bias", ""))
)
return key return key
def test_custom_redis_cache_with_key(): def test_custom_redis_cache_with_key():
messages = [{"role": "user", "content": "write a one line story"}] messages = [{"role": "user", "content": "write a one line story"}]
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.cache.get_cache_key = custom_get_cache_key litellm.cache.get_cache_key = custom_get_cache_key
local_cache = {} local_cache = {}
@ -356,53 +511,71 @@ def test_custom_redis_cache_with_key():
# patch this redis cache get and set call # patch this redis cache get and set call
response1 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response2 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=True, num_retries=3)
response3 = completion(model="gpt-3.5-turbo", messages=messages, temperature=1, caching=False, num_retries=3)
response1 = completion(
    model="gpt-3.5-turbo",
    messages=messages,
    temperature=1,
    caching=True,
    num_retries=3,
)
response2 = completion(
    model="gpt-3.5-turbo",
    messages=messages,
    temperature=1,
    caching=True,
    num_retries=3,
)
response3 = completion(
    model="gpt-3.5-turbo",
    messages=messages,
    temperature=1,
    caching=False,
    num_retries=3,
)
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
print(f"response3: {response3}") print(f"response3: {response3}")
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']: if (
response3["choices"][0]["message"]["content"]
== response2["choices"][0]["message"]["content"]
):
pytest.fail(f"Error occurred:") pytest.fail(f"Error occurred:")
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []
litellm._async_success_callback = [] litellm._async_success_callback = []
# test_custom_redis_cache_with_key() # test_custom_redis_cache_with_key()
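Custom cache keys work by swapping out get_cache_key on the configured cache, as the test above does. A minimal sketch (in-memory cache here; the test uses the Redis backend):

import litellm
from litellm.caching import Cache

def custom_get_cache_key(*args, **kwargs):
    # only these fields participate in the cache key
    return (
        kwargs.get("model", "")
        + str(kwargs.get("messages", ""))
        + str(kwargs.get("temperature", ""))
        + str(kwargs.get("logit_bias", ""))
    )

litellm.cache = Cache()
litellm.cache.get_cache_key = custom_get_cache_key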
def test_cache_override(): def test_cache_override():
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set # test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
# in this case it should not return cached responses # in this case it should not return cached responses
litellm.cache = Cache() litellm.cache = Cache()
print("Testing cache override") print("Testing cache override")
litellm.set_verbose=True litellm.set_verbose = True
# test embedding # test embedding
response1 = embedding( response1 = embedding(
model = "text-embedding-ada-002", model="text-embedding-ada-002", input=["hello who are you"], caching=False
input=[
"hello who are you"
],
caching = False
) )
start_time = time.time() start_time = time.time()
response2 = embedding( response2 = embedding(
model = "text-embedding-ada-002", model="text-embedding-ada-002", input=["hello who are you"], caching=False
input=[
"hello who are you"
],
caching = False
) )
end_time = time.time() end_time = time.time()
print(f"Embedding 2 response time: {end_time - start_time} seconds") print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached. assert (
end_time - start_time > 0.1
) # ensure 2nd response comes in over 0.1s. This should not be cached.
# test_cache_override() # test_cache_override()
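Per-call opt-out is the inverse of the pattern above: even with a global litellm.cache set, caching=False forces a fresh API call. A minimal sketch:

import litellm
from litellm import embedding
from litellm.caching import Cache

litellm.cache = Cache()  # a global cache is configured...

# ...but caching=False opts this call out, so it always hits the API
response = embedding(
    model="text-embedding-ada-002",
    input=["hello who are you"],
    caching=False,
)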
@ -411,10 +584,10 @@ def test_custom_redis_cache_params():
try: try:
litellm.cache = Cache( litellm.cache = Cache(
type="redis", type="redis",
host=os.environ['REDIS_HOST'], host=os.environ["REDIS_HOST"],
port=os.environ['REDIS_PORT'], port=os.environ["REDIS_PORT"],
password=os.environ['REDIS_PASSWORD'], password=os.environ["REDIS_PASSWORD"],
db = 0, db=0,
ssl=True, ssl=True,
ssl_certfile="./redis_user.crt", ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key", ssl_keyfile="./redis_user_private.key",
@ -431,58 +604,126 @@ def test_custom_redis_cache_params():
def test_get_cache_key(): def test_get_cache_key():
from litellm.caching import Cache from litellm.caching import Cache
try: try:
print("Testing get_cache_key") print("Testing get_cache_key")
cache_instance = Cache() cache_instance = Cache()
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
)
cache_key = cache_instance.get_cache_key(
    **{
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "user", "content": "write a one sentence poem about: 7510"}
        ],
        "max_tokens": 40,
        "temperature": 0.2,
        "stream": True,
        "litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
        "litellm_logging_obj": {},
    }
)
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
)
cache_key_2 = cache_instance.get_cache_key(
    **{
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "user", "content": "write a one sentence poem about: 7510"}
        ],
        "max_tokens": 40,
        "temperature": 0.2,
        "stream": True,
        "litellm_call_id": "ffe75e7e-8a07-431f-9a74-71a5b9f35f0b",
        "litellm_logging_obj": {},
    }
)
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40" assert (
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs" cache_key
== "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
)
assert (
cache_key == cache_key_2
), f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
embedding_cache_key = cache_instance.get_cache_key( embedding_cache_key = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', **{
'api_key': '', 'api_version': '2023-07-01-preview', "model": "azure/azure-embedding-model",
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
'caching': True, "api_key": "",
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>" "api_version": "2023-07-01-preview",
"timeout": None,
"max_retries": 0,
"input": ["hi who is ishaan"],
"caching": True,
"client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
} }
) )
print(embedding_cache_key) print(embedding_cache_key)
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs" assert (
embedding_cache_key
== "model: azure/azure-embedding-modelinput: ['hi who is ishaan']"
), f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
# Proxy - embedding cache, test if embedding key, gets model_group and not model # Proxy - embedding cache, test if embedding key, gets model_group and not model
embedding_cache_key_2 = cache_instance.get_cache_key( embedding_cache_key_2 = cache_instance.get_cache_key(
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/', **{
'api_key': '', 'api_version': '2023-07-01-preview', "model": "azure/azure-embedding-model",
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'], "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
'caching': True, "api_key": "",
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>", "api_version": "2023-07-01-preview",
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings', "timeout": None,
'method': 'POST', "max_retries": 0,
'headers': "input": ["hi who is ishaan"],
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', "caching": True,
'content-length': '80'}, "client": "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}}, "proxy_server_request": {
'user': None, "url": "http://0.0.0.0:8000/embeddings",
'metadata': {'user_api_key': None, "method": "POST",
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'}, "headers": {
'model_group': 'EMBEDDING_MODEL_GROUP', "host": "0.0.0.0:8000",
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'}, "user-agent": "curl/7.88.1",
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'}, "accept": "*/*",
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'} "content-type": "application/json",
"content-length": "80",
},
"body": {
"model": "azure-embedding-model",
"input": ["hi who is ishaan"],
},
},
"user": None,
"metadata": {
"user_api_key": None,
"headers": {
"host": "0.0.0.0:8000",
"user-agent": "curl/7.88.1",
"accept": "*/*",
"content-type": "application/json",
"content-length": "80",
},
"model_group": "EMBEDDING_MODEL_GROUP",
"deployment": "azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview",
},
"model_info": {
"mode": "embedding",
"base_model": "text-embedding-ada-002",
"id": "20b2b515-f151-4dd5-a74f-2231e2f54e29",
},
"litellm_call_id": "2642e009-b3cd-443d-b5dd-bb7d56123b0e",
"litellm_logging_obj": "<litellm.utils.Logging object at 0x12f1bddb0>",
}
) )
print(embedding_cache_key_2) print(embedding_cache_key_2)
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']" assert (
embedding_cache_key_2
== "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
)
print("passed!") print("passed!")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
pytest.fail(f"Error occurred:", e) pytest.fail(f"Error occurred:", e)
test_get_cache_key() test_get_cache_key()
# test_custom_redis_cache_params() # test_custom_redis_cache_params()
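Side note on the expected key strings above: the cache key is a deterministic concatenation of a fixed subset of the call kwargs (model or model_group, messages or input, temperature, max_tokens), so identical kwargs always produce identical keys while call ids, logging objects and clients are ignored. A minimal sketch of that idea, independent of litellm's actual implementation:

def build_cache_key(**kwargs):
    # Only a fixed subset of kwargs participates in the key, so volatile
    # metadata (call ids, logging objects, clients) never changes it.
    relevant = ["model", "messages", "temperature", "max_tokens", "input"]
    return "".join(f"{k}: {kwargs[k]}" for k in relevant if k in kwargs)

key_a = build_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=40,
    litellm_call_id="abc",
)
key_b = build_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=40,
    litellm_call_id="xyz",
)
assert key_a == key_b  # different call ids, same cache key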


@@ -18,15 +18,26 @@ from litellm import embedding, completion, Router
from litellm.caching import Cache from litellm.caching import Cache
messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}] messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
def test_caching_v2(): # test in memory cache def test_caching_v2(): # test in memory cache
try: try:
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") litellm.cache = Cache(
type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True) response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
litellm.cache = None # disable cache litellm.cache = None # disable cache
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
raise Exception() raise Exception()
@@ -34,6 +45,7 @@ def test_caching_v2(): # test in memory cache
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_v2() # test_caching_v2()
@@ -49,26 +61,41 @@ def test_caching_router():
"model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE") "api_base": os.getenv("AZURE_API_BASE"),
}, },
"tpm": 240000, "tpm": 240000,
"rpm": 1800 "rpm": 1800,
} }
] ]
litellm.cache = Cache(type="redis", host="os.environ/REDIS_HOST_2", port="os.environ/REDIS_PORT_2", password="os.environ/REDIS_PASSWORD_2", ssl="os.environ/REDIS_SSL_2") litellm.cache = Cache(
router = Router(model_list=model_list, type="redis",
host="os.environ/REDIS_HOST_2",
port="os.environ/REDIS_PORT_2",
password="os.environ/REDIS_PASSWORD_2",
ssl="os.environ/REDIS_SSL_2",
)
router = Router(
model_list=model_list,
routing_strategy="simple-shuffle", routing_strategy="simple-shuffle",
set_verbose=False, set_verbose=False,
num_retries=1) # type: ignore num_retries=1,
) # type: ignore
response1 = completion(model="gpt-3.5-turbo", messages=messages) response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages) response2 = completion(model="gpt-3.5-turbo", messages=messages)
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']: if (
response2["choices"][0]["message"]["content"]
!= response1["choices"][0]["message"]["content"]
):
print(f"response1: {response1}") print(f"response1: {response1}")
print(f"response2: {response2}") print(f"response2: {response2}")
litellm.cache = None # disable cache litellm.cache = None # disable cache
assert response2['choices'][0]['message']['content'] == response1['choices'][0]['message']['content'] assert (
response2["choices"][0]["message"]["content"]
== response1["choices"][0]["message"]["content"]
)
except Exception as e: except Exception as e:
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_caching_router() # test_caching_router()

File diff suppressed because it is too large


@@ -28,6 +28,7 @@ def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}") # print(f"user_model_dict: {user_model_dict}")
pass pass
# normal call # normal call
def test_completion_custom_provider_model_name(): def test_completion_custom_provider_model_name():
try: try:
@@ -41,25 +42,31 @@ def test_completion_custom_provider_model_name():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# completion with num retries + impact on exception mapping # completion with num retries + impact on exception mapping
def test_completion_with_num_retries(): def test_completion_with_num_retries():
try: try:
response = completion(model="j2-ultra", messages=[{"messages": "vibe", "bad": "message"}], num_retries=2) response = completion(
model="j2-ultra",
messages=[{"messages": "vibe", "bad": "message"}],
num_retries=2,
)
pytest.fail(f"Unmapped exception occurred") pytest.fail(f"Unmapped exception occurred")
except Exception as e: except Exception as e:
pass pass
# test_completion_with_num_retries() # test_completion_with_num_retries()
def test_completion_with_0_num_retries(): def test_completion_with_0_num_retries():
try: try:
litellm.set_verbose=False litellm.set_verbose = False
print("making request") print("making request")
# Use the completion function # Use the completion function
response = completion( response = completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"gm": "vibe", "role": "user"}], messages=[{"gm": "vibe", "role": "user"}],
max_retries=4 max_retries=4,
) )
print(response) print(response)
@@ -69,5 +76,6 @@ def test_completion_with_0_num_retries():
print("exception", e) print("exception", e)
pass pass
# Call the test function # Call the test function
test_completion_with_0_num_retries() test_completion_with_0_num_retries()


@@ -15,77 +15,104 @@ from litellm import completion_with_config
config = { config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"model": { "model": {
"claude-instant-1": { "claude-instant-1": {"needs_moderation": True},
"needs_moderation": True
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"error_handling": { "error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
} }
} },
} },
} }
def test_config_context_window_exceeded(): def test_config_context_window_exceeded():
try: try:
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_window_exceeded() # test_config_context_window_exceeded()
def test_config_context_moderation(): def test_config_context_moderation():
try: try:
messages=[{"role": "user", "content": "I want to kill them."}] messages = [{"role": "user", "content": "I want to kill them."}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config) response = completion_with_config(
model="claude-instant-1", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_moderation() # test_config_context_moderation()
def test_config_context_default_fallback(): def test_config_context_default_fallback():
try: try:
messages=[{"role": "user", "content": "Hey, how's it going?"}] messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = completion_with_config(model="claude-instant-1", messages=messages, config=config, api_key="bad-key") response = completion_with_config(
model="claude-instant-1",
messages=messages,
config=config,
api_key="bad-key",
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
# test_config_context_default_fallback() # test_config_context_default_fallback()
config = { config = {
"default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"], "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
"available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613", "available_models": [
"j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"], "gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"j2-ultra",
"command-nightly",
"togethercomputer/llama-2-70b-chat",
"chat-bison",
"chat-bison@001",
"claude-2",
],
"adapt_to_prompt_size": True, # type: ignore "adapt_to_prompt_size": True, # type: ignore
"model": { "model": {
"claude-instant-1": { "claude-instant-1": {"needs_moderation": True},
"needs_moderation": True
},
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
"error_handling": { "error_handling": {
"ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"} "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
} }
} },
} },
} }
def test_config_context_adapt_to_prompt(): def test_config_context_adapt_to_prompt():
try: try:
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config) response = completion_with_config(
model="gpt-3.5-turbo", messages=messages, config=config
)
print(response) print(response)
except Exception as e: except Exception as e:
print(f"Exception: {e}") print(f"Exception: {e}")
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
test_config_context_adapt_to_prompt() test_config_context_adapt_to_prompt()
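For reference, the "error_handling" entries in the config above map an exception name to a per-model fallback model, while "default_fallback_models" lists general-purpose fallbacks. A rough sketch of how such a config could be consulted, using a hypothetical pick_fallback helper rather than litellm's real dispatch logic:

cfg = {
    "default_fallback_models": ["gpt-3.5-turbo"],
    "model": {
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        }
    },
}


def pick_fallback(cfg, model, error_name):
    # Prefer a model-specific mapping, otherwise fall back to the global default list.
    handler = cfg.get("model", {}).get(model, {}).get("error_handling", {})
    fallback = handler.get(error_name, {}).get("fallback_model")
    return fallback or cfg.get("default_fallback_models", [None])[0]


assert pick_fallback(cfg, "gpt-3.5-turbo", "ContextWindowExceededError") == "gpt-3.5-turbo-16k"
assert pick_fallback(cfg, "claude-instant-1", "ContextWindowExceededError") == "gpt-3.5-turbo"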


@@ -4,6 +4,8 @@ from dotenv import load_dotenv
import os import os
load_dotenv() load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try: try:
print(f"api_key: {api_key}") print(f"api_key: {api_key}")


@@ -2,6 +2,7 @@ from litellm.integrations.custom_logger import CustomLogger
import inspect import inspect
import litellm import litellm
class testCustomCallbackProxy(CustomLogger): class testCustomCallbackProxy(CustomLogger):
def __init__(self): def __init__(self):
self.success: bool = False # type: ignore self.success: bool = False # type: ignore
@@ -24,7 +25,11 @@ class testCustomCallbackProxy(CustomLogger):
print(f"{blue_color_code}Initialized LiteLLM custom logger") print(f"{blue_color_code}Initialized LiteLLM custom logger")
try: try:
print(f"Logger Initialized with following methods:") print(f"Logger Initialized with following methods:")
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))] methods = [
method
for method in dir(self)
if inspect.ismethod(getattr(self, method))
]
# Pretty print the methods # Pretty print the methods
for method in methods: for method in methods:
@@ -55,7 +60,10 @@ class testCustomCallbackProxy(CustomLogger):
self.async_success = True self.async_success = True
print("Value of async success: ", self.async_success) print("Value of async success: ", self.async_success)
print("\n kwargs: ", kwargs) print("\n kwargs: ", kwargs)
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada": if (
kwargs.get("model") == "azure-embedding-model"
or kwargs.get("model") == "ada"
):
print("Got an embedding model", kwargs.get("model")) print("Got an embedding model", kwargs.get("model"))
print("Setting embedding success to True") print("Setting embedding success to True")
self.async_success_embedding = True self.async_success_embedding = True
@@ -65,7 +73,6 @@ class testCustomCallbackProxy(CustomLogger):
if kwargs.get("stream") == True: if kwargs.get("stream") == True:
self.streaming_response_obj = response_obj self.streaming_response_obj = response_obj
self.async_completion_kwargs = kwargs self.async_completion_kwargs = kwargs
model = kwargs.get("model", None) model = kwargs.get("model", None)
@@ -74,7 +81,9 @@ class testCustomCallbackProxy(CustomLogger):
# Access litellm_params passed to litellm.completion(), example access `metadata` # Access litellm_params passed to litellm.completion(), example access `metadata`
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) # headers passed to LiteLLM proxy, can be found here metadata = litellm_params.get(
"metadata", {}
) # headers passed to LiteLLM proxy, can be found here
# Calculate cost using litellm.completion_cost() # Calculate cost using litellm.completion_cost()
cost = litellm.completion_cost(completion_response=response_obj) cost = litellm.completion_cost(completion_response=response_obj)
@@ -84,7 +93,6 @@ class testCustomCallbackProxy(CustomLogger):
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger)) print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
print( print(
f""" f"""
Model: {model}, Model: {model},
@@ -98,7 +106,6 @@ class testCustomCallbackProxy(CustomLogger):
) )
return return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
print(f"On Async Failure") print(f"On Async Failure")
self.async_failure = True self.async_failure = True
@@ -110,4 +117,5 @@ class testCustomCallbackProxy(CustomLogger):
self.async_completion_kwargs_fail = kwargs self.async_completion_kwargs_fail = kwargs
my_custom_logger = testCustomCallbackProxy() my_custom_logger = testCustomCallbackProxy()


@@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List, Union from typing import Optional, Literal, List, Union
from litellm import completion, embedding, Cache from litellm import completion, embedding, Cache
import litellm import litellm
@@ -25,14 +26,32 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. litellm.completion() + litellm.embeddings() ## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests ## refer to test_custom_callback_input_router.py for the router + proxy tests
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
""" """
The set of expected inputs to a custom handler for a The set of expected inputs to a custom handler for a
""" """
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
self.errors = [] self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
try: try:
@@ -42,13 +61,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## MESSAGES ## MESSAGES
assert isinstance(messages, list) assert isinstance(messages, list)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -63,18 +82,24 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str)) assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -89,18 +114,29 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse) assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -115,18 +151,25 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse) assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -141,18 +184,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(kwargs["input"], list)
assert isinstance(kwargs['log_event_type'], str) and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -165,13 +218,15 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## MESSAGES ## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict) assert isinstance(messages, list) and isinstance(messages[0], dict)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) assert isinstance(kwargs["messages"], list) and isinstance(
assert isinstance(kwargs['optional_params'], dict) kwargs["messages"][0], dict
assert isinstance(kwargs['litellm_params'], dict) )
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -184,20 +239,28 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## END TIME ## END TIME
assert isinstance(end_time, datetime) assert isinstance(end_time, datetime)
## RESPONSE OBJECT ## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)) assert isinstance(
response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, dict, str)) assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool) assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
@@ -213,18 +276,24 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
## RESPONSE OBJECT ## RESPONSE OBJECT
assert response_obj == None assert response_obj == None
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs['input'], (list, str, dict)) assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs['api_key'], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None assert (
assert isinstance(kwargs['additional_args'], (dict, type(None))) isinstance(
assert isinstance(kwargs['log_event_type'], str) kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except: except:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc()) self.errors.append(traceback.format_exc())
@@ -236,29 +305,26 @@ def test_chat_openai_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = litellm.completion(model="gpt-3.5-turbo", response = litellm.completion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}],
"content": "Hi 👋 - i'm sync openai" )
}])
## test streaming ## test streaming
response = litellm.completion(model="gpt-3.5-turbo", response = litellm.completion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" stream=True,
}], )
stream=True)
for chunk in response: for chunk in response:
continue continue
## test failure callback ## test failure callback
try: try:
response = litellm.completion(model="gpt-3.5-turbo", response = litellm.completion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai"
}],
api_key="my-bad-key", api_key="my-bad-key",
stream=True) stream=True,
)
for chunk in response: for chunk in response:
continue continue
except: except:
@@ -270,37 +336,36 @@ def test_chat_openai_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_openai_stream() # test_chat_openai_stream()
## Test OpenAI + Async ## Test OpenAI + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_openai_stream(): async def test_async_chat_openai_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="gpt-3.5-turbo", response = await litellm.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" )
}])
## test streaming ## test streaming
response = await litellm.acompletion(model="gpt-3.5-turbo", response = await litellm.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai" stream=True,
}], )
stream=True)
async for chunk in response: async for chunk in response:
continue continue
## test failure callback ## test failure callback
try: try:
response = await litellm.acompletion(model="gpt-3.5-turbo", response = await litellm.acompletion(
messages=[{ model="gpt-3.5-turbo",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
"content": "Hi 👋 - i'm openai"
}],
api_key="my-bad-key", api_key="my-bad-key",
stream=True) stream=True,
)
async for chunk in response: async for chunk in response:
continue continue
except: except:
@@ -312,36 +377,35 @@ async def test_async_chat_openai_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_openai_stream()) # asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync ## Test Azure + sync
def test_chat_azure_stream(): def test_chat_azure_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = litellm.completion(model="azure/chatgpt-v-2", response = litellm.completion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
"content": "Hi 👋 - i'm sync azure" )
}])
# test streaming # test streaming
response = litellm.completion(model="azure/chatgpt-v-2", response = litellm.completion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
"content": "Hi 👋 - i'm sync azure" stream=True,
}], )
stream=True)
for chunk in response: for chunk in response:
continue continue
# test failure callback # test failure callback
try: try:
response = litellm.completion(model="azure/chatgpt-v-2", response = litellm.completion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
"content": "Hi 👋 - i'm sync azure"
}],
api_key="my-bad-key", api_key="my-bad-key",
stream=True) stream=True,
)
for chunk in response: for chunk in response:
continue continue
except: except:
@@ -353,37 +417,36 @@ def test_chat_azure_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_azure_stream() # test_chat_azure_stream()
## Test Azure + Async ## Test Azure + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_azure_stream(): async def test_async_chat_azure_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="azure/chatgpt-v-2", response = await litellm.acompletion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
"content": "Hi 👋 - i'm async azure" )
}])
## test streaming ## test streaming
response = await litellm.acompletion(model="azure/chatgpt-v-2", response = await litellm.acompletion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
"content": "Hi 👋 - i'm async azure" stream=True,
}], )
stream=True)
async for chunk in response: async for chunk in response:
continue continue
## test failure callback ## test failure callback
try: try:
response = await litellm.acompletion(model="azure/chatgpt-v-2", response = await litellm.acompletion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
"content": "Hi 👋 - i'm async azure"
}],
api_key="my-bad-key", api_key="my-bad-key",
stream=True) stream=True,
)
async for chunk in response: async for chunk in response:
continue continue
except: except:
@@ -395,36 +458,35 @@ async def test_async_chat_azure_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_azure_stream()) # asyncio.run(test_async_chat_azure_stream())
## Test Bedrock + sync ## Test Bedrock + sync
def test_chat_bedrock_stream(): def test_chat_bedrock_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = litellm.completion(model="bedrock/anthropic.claude-v1", response = litellm.completion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
"content": "Hi 👋 - i'm sync bedrock" )
}])
# test streaming # test streaming
response = litellm.completion(model="bedrock/anthropic.claude-v1", response = litellm.completion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
"content": "Hi 👋 - i'm sync bedrock" stream=True,
}], )
stream=True)
for chunk in response: for chunk in response:
continue continue
# test failure callback # test failure callback
try: try:
response = litellm.completion(model="bedrock/anthropic.claude-v1", response = litellm.completion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
"content": "Hi 👋 - i'm sync bedrock"
}],
aws_region_name="my-bad-region", aws_region_name="my-bad-region",
stream=True) stream=True,
)
for chunk in response: for chunk in response:
continue continue
except: except:
@@ -436,39 +498,38 @@ def test_chat_bedrock_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_bedrock_stream() # test_chat_bedrock_stream()
## Test Bedrock + Async ## Test Bedrock + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_bedrock_stream(): async def test_async_chat_bedrock_stream():
try: try:
customHandler = CompletionCustomHandler() customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler] litellm.callbacks = [customHandler]
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", response = await litellm.acompletion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
"content": "Hi 👋 - i'm async bedrock" )
}])
# test streaming # test streaming
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", response = await litellm.acompletion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
"content": "Hi 👋 - i'm async bedrock" stream=True,
}], )
stream=True)
print(f"response: {response}") print(f"response: {response}")
async for chunk in response: async for chunk in response:
print(f"chunk: {chunk}") print(f"chunk: {chunk}")
continue continue
## test failure callback ## test failure callback
try: try:
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1", response = await litellm.acompletion(
messages=[{ model="bedrock/anthropic.claude-v1",
"role": "user", messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
"content": "Hi 👋 - i'm async bedrock"
}],
aws_region_name="my-bad-key", aws_region_name="my-bad-key",
stream=True) stream=True,
)
async for chunk in response: async for chunk in response:
continue continue
except: except:
@@ -480,8 +541,10 @@ async def test_async_chat_bedrock_stream():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_bedrock_stream()) # asyncio.run(test_async_chat_bedrock_stream())
# EMBEDDING # EMBEDDING
## Test OpenAI + Async ## Test OpenAI + Async
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -490,8 +553,9 @@ async def test_async_embedding_openai():
customHandler_success = CompletionCustomHandler() customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success] litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model", response = await litellm.aembedding(
input=["good morning from litellm"]) model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}") print(f"customHandler_success.states: {customHandler_success.states}")
@@ -500,9 +564,11 @@ async def test_async_embedding_openai():
# test failure callback # test failure callback
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
try: try:
response = await litellm.aembedding(model="text-embedding-ada-002", response = await litellm.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm"], input=["good morning from litellm"],
api_key="my-bad-key") api_key="my-bad-key",
)
except: except:
pass pass
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -513,8 +579,10 @@ async def test_async_embedding_openai():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_openai()) # asyncio.run(test_async_embedding_openai())
## Test Azure + Async ## Test Azure + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_embedding_azure(): async def test_async_embedding_azure():
@@ -522,8 +590,9 @@ async def test_async_embedding_azure():
customHandler_success = CompletionCustomHandler() customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success] litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model", response = await litellm.aembedding(
input=["good morning from litellm"]) model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}") print(f"customHandler_success.states: {customHandler_success.states}")
@@ -532,9 +601,11 @@ async def test_async_embedding_azure():
# test failure callback # test failure callback
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
try: try:
response = await litellm.aembedding(model="azure/azure-embedding-model", response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"], input=["good morning from litellm"],
api_key="my-bad-key") api_key="my-bad-key",
)
except: except:
pass pass
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -545,8 +616,10 @@ async def test_async_embedding_azure():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_azure()) # asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async ## Test Bedrock + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_embedding_bedrock(): async def test_async_embedding_bedrock():
@@ -555,8 +628,11 @@ async def test_async_embedding_bedrock():
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success] litellm.callbacks = [customHandler_success]
litellm.set_verbose = True litellm.set_verbose = True
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3", response = await litellm.aembedding(
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2") model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="os.environ/AWS_REGION_NAME_2",
)
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}") print(f"customHandler_success.states: {customHandler_success.states}")
@@ -565,9 +641,11 @@ async def test_async_embedding_bedrock():
# test failure callback # test failure callback
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
try: try:
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3", response = await litellm.aembedding(
model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"], input=["good morning from litellm"],
aws_region_name="my-bad-region") aws_region_name="my-bad-region",
)
except: except:
pass pass
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -578,54 +656,72 @@ async def test_async_embedding_bedrock():
except Exception as e: except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_bedrock()) # asyncio.run(test_async_embedding_bedrock())
# CACHING # CACHING
## Test Azure - completion, embedding ## Test Azure - completion, embedding
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_completion_azure_caching(): async def test_async_completion_azure_caching():
customHandler_caching = CompletionCustomHandler() customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching] litellm.callbacks = [customHandler_caching]
unique_time = time.time() unique_time = time.time()
response1 = await litellm.acompletion(model="azure/chatgpt-v-2", response1 = await litellm.acompletion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
"content": f"Hi 👋 - i'm async azure {unique_time}" caching=True,
}], )
caching=True)
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(model="azure/chatgpt-v-2", response2 = await litellm.acompletion(
messages=[{ model="azure/chatgpt-v-2",
"role": "user", messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
"content": f"Hi 👋 - i'm async azure {unique_time}" caching=True,
}], )
caching=True)
await asyncio.sleep(1) # success callbacks are done in parallel await asyncio.sleep(1) # success callbacks are done in parallel
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}") print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_embedding_azure_caching(): async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching") print("Testing custom callback input - Azure Caching")
customHandler_caching = CompletionCustomHandler() customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching] litellm.callbacks = [customHandler_caching]
unique_time = time.time() unique_time = time.time()
response1 = await litellm.aembedding(model="azure/azure-embedding-model", response1 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"], input=[f"good morning from litellm1 {unique_time}"],
caching=True) caching=True,
)
await asyncio.sleep(1) # set cache is async for aembedding() await asyncio.sleep(1) # set cache is async for aembedding()
response2 = await litellm.aembedding(model="azure/azure-embedding-model", response2 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"], input=[f"good morning from litellm1 {unique_time}"],
caching=True) caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel await asyncio.sleep(1) # success callbacks are done in parallel
print(customHandler_caching.states) print(customHandler_caching.states)
assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success assert len(customHandler_caching.states) == 4 # pre, post, success, success
# asyncio.run( # asyncio.run(
# test_async_embedding_azure_caching() # test_async_embedding_azure_caching()
# ) # )
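All of the callback tests above follow the same pattern: subclass CustomLogger, register the instance via litellm.callbacks, record which hooks fire, then assert on the recorded state after giving the async success callbacks a moment to run. A stripped-down sketch of that pattern (hook names and signatures mirror the tests above; the handler itself is illustrative only):

import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger


class StateRecorder(CustomLogger):
    # Records which hooks fired so a test can assert on them afterwards.
    def __init__(self):
        self.states = []

    def log_pre_api_call(self, model, messages, kwargs):
        self.states.append("sync_pre_api_call")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")


async def demo():
    handler = StateRecorder()
    litellm.callbacks = [handler]
    await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi 👋"}],
    )
    await asyncio.sleep(1)  # success callbacks run in parallel, give them time
    assert "async_success" in handler.states


# asyncio.run(demo())  # needs a valid OPENAI_API_KEY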


@@ -3,7 +3,8 @@
import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List from typing import Optional, Literal, List
from litellm import Router, Cache from litellm import Router, Cache
import litellm import litellm
@@ -29,39 +30,61 @@ from litellm.integrations.custom_logger import CustomLogger
## 1. router.completion() + router.embeddings() ## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings ## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
""" """
The set of expected inputs to a custom handler for a The set of expected inputs to a custom handler for a
""" """
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
self.errors = [] self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] self.states: Optional[
List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
]
] = []
def log_pre_api_call(self, model, messages, kwargs): def log_pre_api_call(self, model, messages, kwargs):
try: try:
print(f'received kwargs in pre-input: {kwargs}') print(f"received kwargs in pre-input: {kwargs}")
self.states.append("sync_pre_api_call") self.states.append("sync_pre_api_call")
## MODEL ## MODEL
assert isinstance(model, str) assert isinstance(model, str)
## MESSAGES ## MESSAGES
assert isinstance(messages, list) assert isinstance(messages, list)
## KWARGS ## KWARGS
assert isinstance(kwargs['model'], str) assert isinstance(kwargs["model"], str)
assert isinstance(kwargs['messages'], list) assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs['optional_params'], dict) assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs['litellm_params'], dict) assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs['start_time'], (datetime, type(None))) assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs['stream'], bool) assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs['user'], (str, type(None))) assert isinstance(kwargs["user"], (str, type(None)))
### ROUTER-SPECIFIC KWARGS ### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict) assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict) assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None))) assert isinstance(
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None))) kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
)
assert isinstance(
kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
)
assert isinstance(kwargs["litellm_params"]["stream_response"], dict) assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e: except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}") print(f"Assertion Error: {traceback.format_exc()}")
@@ -77,26 +100,36 @@ class CompletionCustomHandler(CustomLogger):
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.iscoroutine(kwargs["original_response"])
                or inspect.isasyncgen(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            ### ROUTER-SPECIFIC KWARGS
            assert isinstance(kwargs["litellm_params"]["metadata"], dict)
            assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
            assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
            assert isinstance(kwargs["litellm_params"]["model_info"], dict)
            assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
            assert isinstance(
                kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
            )
            assert isinstance(
                kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
            )
            assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
@@ -112,18 +145,29 @@ class CompletionCustomHandler(CustomLogger):
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())
@@ -138,18 +182,25 @@ class CompletionCustomHandler(CustomLogger):
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert isinstance(
                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
@@ -165,18 +216,28 @@ class CompletionCustomHandler(CustomLogger):
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or kwargs["original_response"] == None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())
@@ -200,20 +261,28 @@ class CompletionCustomHandler(CustomLogger):
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            ### ROUTER-SPECIFIC KWARGS
            assert isinstance(kwargs["litellm_params"]["metadata"], dict)
@@ -221,8 +290,12 @@ class CompletionCustomHandler(CustomLogger):
            assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
            assert isinstance(kwargs["litellm_params"]["model_info"], dict)
            assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
            assert isinstance(
                kwargs["litellm_params"]["proxy_server_request"], (str, type(None))
            )
            assert isinstance(
                kwargs["litellm_params"]["preset_cache_key"], (str, type(None))
            )
            assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
@@ -239,22 +312,30 @@ class CompletionCustomHandler(CustomLogger):
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, str, dict))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
                or kwargs["original_response"] == None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())
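
The handler above spells out every assertion on the callback payloads. Stripped of the assertions, the hook surface it exercises looks roughly like the sketch below; the method names follow litellm's CustomLogger interface as used in this file, but the class itself is illustrative only and is not part of the commit.

from litellm.integrations.custom_logger import CustomLogger


class StateTrackingHandler(CustomLogger):  # illustrative, not part of this diff
    def __init__(self):
        self.states = []

    def log_pre_api_call(self, model, messages, kwargs):
        self.states.append("sync_pre_api_call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        self.states.append("post_api_call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("sync_success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("sync_failure")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_failure")
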
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
@@ -271,37 +352,39 @@ async def test_async_chat_azure():
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "tpm": 240000,
                "rpm": 1800,
            },
        ]
        router = Router(model_list=model_list)  # type: ignore
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
        )
        await asyncio.sleep(2)
        assert len(customHandler_completion_azure_router.errors) == 0
        assert (
            len(customHandler_completion_azure_router.states) == 3
        )  # pre, post, success
        # streaming
        litellm.callbacks = [customHandler_streaming_azure_router]
        router2 = Router(model_list=model_list)  # type: ignore
        response = await router2.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
            stream=True,
        )
        async for chunk in response:
            print(f"async azure router chunk: {chunk}")
            continue
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
        assert len(customHandler_streaming_azure_router.errors) == 0
        assert (
            len(customHandler_streaming_azure_router.states) >= 4
        )  # pre, post, stream (multiple times), success
        # failure
        model_list = [
            {
@@ -310,20 +393,19 @@ async def test_async_chat_azure():
                    "model": "azure/chatgpt-v-2",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "tpm": 240000,
                "rpm": 1800,
            },
        ]
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list)  # type: ignore
        try:
            response = await router3.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
            )
            print(f"response in router3 acompletion: {response}")
        except:
            pass
@@ -335,6 +417,8 @@ async def test_async_chat_azure():
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_chat_azure())
## EMBEDDING
@pytest.mark.asyncio
@@ -350,15 +434,16 @@ async def test_async_embedding_azure():
                    "model": "azure/azure-embedding-model",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "tpm": 240000,
                "rpm": 1800,
            },
        ]
        router = Router(model_list=model_list)  # type: ignore
        response = await router.aembedding(
            model="azure-embedding-model", input=["hello from litellm!"]
        )
        await asyncio.sleep(2)
        assert len(customHandler.errors) == 0
        assert len(customHandler.states) == 3  # pre, post, success
@@ -370,17 +455,18 @@ async def test_async_embedding_azure():
                    "model": "azure/azure-embedding-model",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "tpm": 240000,
                "rpm": 1800,
            },
        ]
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list)  # type: ignore
        try:
            response = await router3.aembedding(
                model="azure-embedding-model", input=["hello from litellm!"]
            )
            print(f"response in router3 aembedding: {response}")
        except:
            pass
@@ -392,6 +478,8 @@ async def test_async_embedding_azure():
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Fallbacks
## COMPLETION
@@ -408,10 +496,10 @@ async def test_async_chat_azure_with_fallbacks():
                    "model": "azure/chatgpt-v-2",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "tpm": 240000,
                "rpm": 1800,
            },
            {
                "model_name": "gpt-3.5-turbo-16k",
@@ -419,31 +507,40 @@ async def test_async_chat_azure_with_fallbacks():
                    "model": "gpt-3.5-turbo-16k",
                },
                "tpm": 240000,
                "rpm": 1800,
            },
        ]
        router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}])  # type: ignore
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
        )
        await asyncio.sleep(2)
        print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
        assert len(customHandler_fallbacks.errors) == 0
        assert (
            len(customHandler_fallbacks.states) == 6
        )  # pre, post, failure, pre, post, success
        litellm.callbacks = []
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_chat_azure_with_fallbacks())


# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
    customHandler_caching = CompletionCustomHandler()
    litellm.cache = Cache(
        type="redis",
        host=os.environ["REDIS_HOST"],
        port=os.environ["REDIS_PORT"],
        password=os.environ["REDIS_PASSWORD"],
    )
    litellm.callbacks = [customHandler_caching]
    unique_time = time.time()
    model_list = [
@@ -453,10 +550,10 @@ async def test_async_completion_azure_caching():
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
        {
            "model_name": "gpt-3.5-turbo-16k",
@@ -464,25 +561,25 @@ async def test_async_completion_azure_caching():
                "model": "gpt-3.5-turbo-16k",
            },
            "tpm": 240000,
            "rpm": 1800,
        },
    ]
    router = Router(model_list=model_list)  # type: ignore
    response1 = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
        caching=True,
    )
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
        caching=True,
    )
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(
        f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
    )
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success
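
Every router test in this file builds the same model_list shape. As a consolidated, hedged reference, the sketch below wires up the two deployments used by these tests with a fallback group; the credentials come from env vars and the whole block is illustrative rather than part of this commit.

import os

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "gpt-3.5-turbo-16k",
        "litellm_params": {"model": "gpt-3.5-turbo-16k"},
        "tpm": 240000,
        "rpm": 1800,
    },
]

# if the azure deployment fails, the router retries the fallback group
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}])
# response = router.completion(
#     model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}]
# )
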


@@ -1,12 +1,14 @@
import sys
import os
import io, asyncio

# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion
import litellm

litellm.num_retries = 3

import time, random
@@ -29,11 +31,14 @@ def pre_request():
import re


def verify_log_file(log_file_path):
    with open(log_file_path, "r") as log_file:
        log_content = log_file.read()
        print(
            f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content
        )

        # Define the pattern to search for in the log file
        pattern = r"Response from DynamoDB:{.*?}"
@@ -50,7 +55,11 @@ def verify_log_file(log_file_path):
        print(f"Total occurrences of specified response: {len(matches)}")

        # Count the occurrences of successful responses (status code 200 or 201)
        success_count = sum(
            1
            for match in matches
            if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match
        )

        # Print the count of successful responses
        print(f"Count of successful responses from DynamoDB: {success_count}")
@@ -69,41 +78,41 @@ def test_dynamo_logging():
        litellm.set_verbose = True
        original_stdout, log_file, file_name = pre_request()

        print("Testing async dynamoDB logging")

        async def _test():
            return await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                max_tokens=100,
                temperature=0.7,
                user="ishaan-2",
            )

        response = asyncio.run(_test())
        print(f"response: {response}")

        # streaming + async
        async def _test2():
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                max_tokens=10,
                temperature=0.7,
                user="ishaan-2",
                stream=True,
            )
            async for chunk in response:
                pass

        asyncio.run(_test2())

        # aembedding()
        async def _test3():
            return await litellm.aembedding(
                model="text-embedding-ada-002", input=["hi"], user="ishaan-2"
            )

        response = asyncio.run(_test3())
        time.sleep(1)
    except Exception as e:
@@ -117,4 +126,5 @@ def test_dynamo_logging():
    verify_log_file(file_name)
    print("Passed! Testing async dynamoDB logging")


# test_dynamo_logging_async()
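
The hunks above only reformat the DynamoDB logging test; the behaviour under test is litellm writing success logs to DynamoDB and the test scraping stdout to verify the responses. For orientation, a hedged sketch of how that logger is typically switched on is below — the "dynamodb" callback name and the table-name setting follow litellm's DynamoDB logging docs rather than this diff, and the table name itself is a placeholder.

import litellm

# route success logs to the DynamoDB integration (assumed callback name, per litellm docs)
litellm.success_callback = ["dynamodb"]
litellm.dynamodb_table_name = "litellm-logs"  # hypothetical table name

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "This is a test"}],
    max_tokens=10,
)
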


@@ -14,17 +14,18 @@ from litellm import embedding, completion
litellm.set_verbose = False


def test_openai_embedding():
    try:
        litellm.set_verbose = True
        response = embedding(
            model="text-embedding-ada-002",
            input=["good morning from litellm", "this is another item"],
            metadata={"anything": "good day"},
        )
        litellm_response = dict(response)
        litellm_response_keys = set(litellm_response.keys())
        litellm_response_keys.discard("_response_ms")

        print(litellm_response_keys)
        print("LiteLLM Response\n")
@@ -32,21 +33,30 @@ def test_openai_embedding():
        # same request with OpenAI 1.0+
        import openai

        client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        response = client.embeddings.create(
            model="text-embedding-ada-002",
            input=["good morning from litellm", "this is another item"],
        )

        response = dict(response)
        openai_response_keys = set(response.keys())
        print(openai_response_keys)
        assert (
            litellm_response_keys == openai_response_keys
        )  # ENSURE the Keys in litellm response is exactly what the openai package returns
        assert (
            len(litellm_response["data"]) == 2
        )  # expect two embedding responses from litellm_response since input had two
        print(openai_response_keys)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_openai_embedding()


def test_openai_azure_embedding_simple():
    try:
        response = embedding(
@@ -55,12 +65,15 @@ def test_openai_azure_embedding_simple():
        )
        print(response)
        response_keys = set(dict(response).keys())
        response_keys.discard("_response_ms")
        assert set(["usage", "model", "object", "data"]) == set(
            response_keys
        )  # assert litellm response has expected keys from OpenAI embedding response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_openai_azure_embedding_simple()
@@ -69,41 +82,50 @@ def test_openai_azure_embedding_timeouts():
        response = embedding(
            model="azure/azure-embedding-model",
            input=["good morning from litellm"],
            timeout=0.00001,
        )
        print(response)
    except openai.APITimeoutError:
        print("Good job got timeout error!")
        pass
    except Exception as e:
        pytest.fail(
            f"Expected timeout error, did not get the correct error. Instead got {e}"
        )


# test_openai_azure_embedding_timeouts()


def test_openai_embedding_timeouts():
    try:
        response = embedding(
            model="text-embedding-ada-002",
            input=["good morning from litellm"],
            timeout=0.00001,
        )
        print(response)
    except openai.APITimeoutError:
        print("Good job got OpenAI timeout error!")
        pass
    except Exception as e:
        pytest.fail(
            f"Expected timeout error, did not get the correct error. Instead got {e}"
        )


# test_openai_embedding_timeouts()


def test_openai_azure_embedding():
    try:
        api_key = os.environ["AZURE_API_KEY"]
        api_base = os.environ["AZURE_API_BASE"]
        api_version = os.environ["AZURE_API_VERSION"]

        os.environ["AZURE_API_VERSION"] = ""
        os.environ["AZURE_API_BASE"] = ""
        os.environ["AZURE_API_KEY"] = ""

        response = embedding(
            model="azure/azure-embedding-model",
@@ -114,33 +136,37 @@ def test_openai_azure_embedding():
        )
        print(response)

        os.environ["AZURE_API_VERSION"] = api_version
        os.environ["AZURE_API_BASE"] = api_base
        os.environ["AZURE_API_KEY"] = api_key
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_openai_azure_embedding()

# test_openai_embedding()


def test_cohere_embedding():
    try:
        # litellm.set_verbose=True
        response = embedding(
            model="embed-english-v2.0",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response:", response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_cohere_embedding()


def test_cohere_embedding3():
    try:
        litellm.set_verbose = True
        response = embedding(
            model="embed-english-v3.0",
            input=["good morning from litellm", "this is another item"],
@@ -149,97 +175,135 @@ def test_cohere_embedding3():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_cohere_embedding3()


def test_bedrock_embedding_titan():
    try:
        litellm.set_verbose = True
        response = embedding(
            model="amazon.titan-embed-text-v1",
            input=[
                "good morning from litellm, attempting to embed data",
                "lets test a second string for good measure",
            ],
        )
        print(f"response:", response)
        assert isinstance(
            response["data"][0]["embedding"], list
        ), "Expected response to be a list"
        print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
        assert all(
            isinstance(x, float) for x in response["data"][0]["embedding"]
        ), "Expected response to be a list of floats"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


test_bedrock_embedding_titan()


def test_bedrock_embedding_cohere():
    try:
        litellm.set_verbose = False
        response = embedding(
            model="cohere.embed-multilingual-v3",
            input=[
                "good morning from litellm, attempting to embed data",
                "lets test a second string for good measure",
            ],
            aws_region_name="os.environ/AWS_REGION_NAME_2",
        )
        assert isinstance(
            response["data"][0]["embedding"], list
        ), "Expected response to be a list"
        print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
        assert all(
            isinstance(x, float) for x in response["data"][0]["embedding"]
        ), "Expected response to be a list of floats"
        # print(f"response:", response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_bedrock_embedding_cohere()


# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
    try:
        # huggingface/microsoft/codebert-base
        # huggingface/facebook/bart-large
        response = embedding(
            model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response:", response)
    except Exception as e:
        # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
        pass


# test_hf_embedding()


# test async embeddings
def test_aembedding():
    try:
        import asyncio

        async def embedding_call():
            try:
                response = await litellm.aembedding(
                    model="text-embedding-ada-002",
                    input=["good morning from litellm", "this is another item"],
                )
                print(response)
            except Exception as e:
                pytest.fail(f"Error occurred: {e}")

        asyncio.run(embedding_call())
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_aembedding()


def test_aembedding_azure():
    try:
        import asyncio

        async def embedding_call():
            try:
                response = await litellm.aembedding(
                    model="azure/azure-embedding-model",
                    input=["good morning from litellm", "this is another item"],
                )
                print(response)
            except Exception as e:
                pytest.fail(f"Error occurred: {e}")

        asyncio.run(embedding_call())
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_aembedding_azure()


def test_sagemaker_embeddings():
    try:
        response = litellm.embedding(
            model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
            input=["good morning from litellm", "this is another item"],
        )
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_sagemaker_embeddings()

# def local_proxy_embeddings():
#     litellm.set_verbose=True
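
All of the providers exercised above return the same OpenAI-shaped embedding payload, which is why the assertions keep reaching into response["data"][i]["embedding"]. A small hedged sketch of pulling the raw vectors out of that shape — the model and inputs are placeholders, and any provider from docs.litellm.ai/docs/providers could be substituted:

import litellm

response = litellm.embedding(
    model="text-embedding-ada-002",  # placeholder model
    input=["good morning from litellm", "this is another item"],
)

# response follows the OpenAI embedding schema: usage / model / object / data
vectors = [item["embedding"] for item in response["data"]]
assert len(vectors) == 2 and all(isinstance(x, float) for x in vectors[0])
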


@@ -11,17 +11,18 @@ import litellm
from litellm import (
    embedding,
    completion,
    # AuthenticationError,
    ContextWindowExceededError,
    # RateLimitError,
    # ServiceUnavailableError,
    # OpenAIError,
)
from concurrent.futures import ThreadPoolExecutor
import pytest

litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
litellm.num_retries = 0

# litellm.failure_callback = ["sentry"]
#### What this tests ####
@@ -36,6 +37,7 @@ litellm.num_retries=0
models = ["command-nightly"]


# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
@@ -56,13 +58,23 @@ def test_context_window(model):
        print(f"{e}")
        pytest.fail(f"An error occcurred - {e}")


@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
    ctx_window_fallback_dict = {
        "command-nightly": "claude-2",
        "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
        "azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
    }
    sample_text = "how does a court case get to the Supreme Court?" * 1000
    messages = [{"content": sample_text, "role": "user"}]

    completion(
        model=model,
        messages=messages,
        context_window_fallback_dict=ctx_window_fallback_dict,
    )


# for model in litellm.models_by_provider["bedrock"]:
#     test_context_window(model=model)
@@ -98,7 +110,9 @@ def invalid_auth(model):
            os.environ["AI21_API_KEY"] = "bad-key"
        elif "togethercomputer" in model:
            temporary_key = os.environ["TOGETHERAI_API_KEY"]
            os.environ[
                "TOGETHERAI_API_KEY"
            ] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
        elif model in litellm.openrouter_models:
            temporary_key = os.environ["OPENROUTER_API_KEY"]
            os.environ["OPENROUTER_API_KEY"] = "bad-key"
@@ -115,9 +129,7 @@ def invalid_auth(model):
            temporary_key = os.environ["REPLICATE_API_KEY"]
            os.environ["REPLICATE_API_KEY"] = "bad-key"
        print(f"model: {model}")
        response = completion(model=model, messages=messages)
        print(f"response: {response}")
    except AuthenticationError as e:
        print(f"AuthenticationError Caught Exception - {str(e)}")
@@ -148,7 +160,7 @@ def invalid_auth(model):
            os.environ["REPLICATE_API_KEY"] = temporary_key
        elif "j2" in model:
            os.environ["AI21_API_KEY"] = temporary_key
        elif "togethercomputer" in model:
            os.environ["TOGETHERAI_API_KEY"] = temporary_key
        elif model in litellm.aleph_alpha_models:
            os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
@@ -160,10 +172,12 @@ def invalid_auth(model):
            os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
    return


# for model in litellm.models_by_provider["bedrock"]:
#     invalid_auth(model=model)

# invalid_auth(model="command-nightly")


# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):
@@ -173,23 +187,18 @@ def test_invalid_request_error(model):
        completion(model=model, messages=messages, max_tokens="hello world")


def test_completion_azure_exception():
    try:
        import openai

        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
        response = completion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "hello"}],
        )
        os.environ["AZURE_API_KEY"] = old_azure_key
        print(f"response: {response}")
@@ -199,25 +208,24 @@ def test_completion_azure_exception():
        print("good job got the correct error for azure when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_azure_exception()


async def asynctest_completion_azure_exception():
    try:
        import openai
        import litellm

        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{"role": "user", "content": "hello"}],
        )
        print(f"response: {response}")
        print(response)
@@ -229,6 +237,8 @@ async def asynctest_completion_azure_exception():
        print("Got wrong exception")
        print("exception", e)
        pytest.fail(f"Error occurred: {e}")


# import asyncio
# asyncio.run(
#     asynctest_completion_azure_exception()
@@ -239,19 +249,17 @@ def asynctest_completion_openai_exception_bad_model():
    try:
        import openai
        import litellm, asyncio

        print("azure exception bad model\n\n")
        litellm.set_verbose = True
        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="openai/gpt-6",
                messages=[{"role": "user", "content": "hello"}],
            )

        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")
@@ -261,27 +269,25 @@ def asynctest_completion_openai_exception_bad_model():
        assert isinstance(e, openai.BadRequestError)
        pytest.fail(f"Error occurred: {e}")


# asynctest_completion_openai_exception_bad_model()


def asynctest_completion_azure_exception_bad_model():
    try:
        import openai
        import litellm, asyncio

        print("azure exception bad model\n\n")
        litellm.set_verbose = True
        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="azure/gpt-12",
                messages=[{"role": "user", "content": "hello"}],
            )

        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")
@@ -290,25 +296,23 @@ def asynctest_completion_azure_exception_bad_model():
        print("Raised wrong type of exception", type(e))
        pytest.fail(f"Error occurred: {e}")


# asynctest_completion_azure_exception_bad_model()


def test_completion_openai_exception():
    # test if openai:gpt raises openai.AuthenticationError
    try:
        import openai

        print("openai gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["OPENAI_API_KEY"]
        os.environ["OPENAI_API_KEY"] = "good morning"
        response = completion(
            model="gpt-4",
            messages=[{"role": "user", "content": "hello"}],
        )
        print(f"response: {response}")
        print(response)
@ -317,25 +321,24 @@ def test_completion_openai_exception():
print("OpenAI: good job got the correct error for openai when key not set") print("OpenAI: good job got the correct error for openai when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception() # test_completion_openai_exception()
def test_completion_mistral_exception(): def test_completion_mistral_exception():
# test if mistral/mistral-tiny raises openai.AuthenticationError # test if mistral/mistral-tiny raises openai.AuthenticationError
try: try:
import openai import openai
print("Testing mistral ai exception mapping") print("Testing mistral ai exception mapping")
litellm.set_verbose=True litellm.set_verbose = True
## Test azure call ## Test azure call
old_azure_key = os.environ["MISTRAL_API_KEY"] old_azure_key = os.environ["MISTRAL_API_KEY"]
os.environ["MISTRAL_API_KEY"] = "good morning" os.environ["MISTRAL_API_KEY"] = "good morning"
response = completion( response = completion(
model="mistral/mistral-tiny", model="mistral/mistral-tiny",
messages=[ messages=[{"role": "user", "content": "hello"}],
{
"role": "user",
"content": "hello"
}
],
) )
print(f"response: {response}") print(f"response: {response}")
print(response) print(response)
@ -344,11 +347,11 @@ def test_completion_mistral_exception():
print("good job got the correct error for openai when key not set") print("good job got the correct error for openai when key not set")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception() # test_completion_mistral_exception()
# # test_invalid_request_error(model="command-nightly") # # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors # # Test 3: Rate Limit Errors
# def test_model_call(model): # def test_model_call(model):
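Taken together, the exception tests in this file follow one pattern: swap the provider key for a junk value, call the completion API, and expect the provider failure to surface as the matching OpenAI exception class. A minimal sketch of that pattern, where the helper name and its arguments are illustrative rather than part of the test file:

import os

import openai
from litellm import completion


def check_auth_error_mapping(model: str, key_env_var: str) -> bool:
    """Return True if a junk key for `model` surfaces as openai.AuthenticationError."""
    old_key = os.environ.get(key_env_var, "")
    os.environ[key_env_var] = "bad-key"  # intentionally invalid, like the tests above
    try:
        completion(model=model, messages=[{"role": "user", "content": "hello"}])
        return False  # no error raised, so the mapping was not exercised
    except openai.AuthenticationError:
        return True  # provider failure mapped to the OpenAI exception type
    except Exception:
        return False  # some other exception type leaked through
    finally:
        os.environ[key_env_var] = old_key  # restore whatever key was set before


# e.g. check_auth_error_mapping("mistral/mistral-tiny", "MISTRAL_API_KEY")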
@@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
import pytest

litellm.num_retries = 0
litellm.cache = None
# litellm.set_verbose=True
@@ -20,23 +21,32 @@ import json
# litellm.success_callback = ["langfuse"]


def get_current_weather(location, unit="fahrenheit"):
    """Get the current weather in a given location"""
    if "tokyo" in location.lower():
        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
    elif "san francisco" in location.lower():
        return json.dumps(
            {"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
        )
    elif "paris" in location.lower():
        return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
    else:
        return json.dumps({"location": location, "temperature": "unknown"})


# Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API
def test_parallel_function_call():
    try:
        # Step 1: send the conversation and available functions to the model
        messages = [
            {
                "role": "user",
                "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
            }
        ]
        tools = [
            {
                "type": "function",
@@ -50,7 +60,10 @@ def test_parallel_function_call():
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
@@ -69,7 +82,9 @@ def test_parallel_function_call():
        print("length of tool calls", len(tool_calls))
        print("Expecting there to be 3 tool calls")
        assert (
            len(tool_calls) > 1
        )  # this has to call the function for SF, Tokyo and parise

        # Step 2: check if the model wanted to call a function
        if tool_calls:
@@ -78,7 +93,9 @@ def test_parallel_function_call():
            available_functions = {
                "get_current_weather": get_current_weather,
            }  # only one function in this example, but you can have multiple
            messages.append(
                response_message
            )  # extend conversation with assistant's reply
            print("Response message\n", response_message)
            # Step 4: send the info for each function call and function response to the model
            for tool_call in tool_calls:
@@ -99,25 +116,26 @@ def test_parallel_function_call():
                )  # extend conversation with function response
            print(f"messages: {messages}")
            second_response = litellm.completion(
                model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22
            )  # get a new response from the model where it can see the function response
            print("second response\n", second_response)
            return second_response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


test_parallel_function_call()
def test_parallel_function_call_stream():
    try:
        # Step 1: send the conversation and available functions to the model
        messages = [
            {
                "role": "user",
                "content": "What's the weather like in San Francisco, Tokyo, and Paris?",
            }
        ]
        tools = [
            {
                "type": "function",
@@ -131,7 +149,10 @@ def test_parallel_function_call_stream():
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
@@ -144,7 +165,7 @@ def test_parallel_function_call_stream():
            tools=tools,
            stream=True,
            tool_choice="auto",  # auto is default, but we'll be explicit
            complete_response=True,
        )
        print("Response\n", response)
        # for chunk in response:
@@ -154,7 +175,9 @@ def test_parallel_function_call_stream():
        print("length of tool calls", len(tool_calls))
        print("Expecting there to be 3 tool calls")
        assert (
            len(tool_calls) > 1
        )  # this has to call the function for SF, Tokyo and parise

        # Step 2: check if the model wanted to call a function
        if tool_calls:
@@ -163,7 +186,9 @@ def test_parallel_function_call_stream():
            available_functions = {
                "get_current_weather": get_current_weather,
            }  # only one function in this example, but you can have multiple
            messages.append(
                response_message
            )  # extend conversation with assistant's reply
            print("Response message\n", response_message)
            # Step 4: send the info for each function call and function response to the model
            for tool_call in tool_calls:
@@ -184,14 +209,12 @@ def test_parallel_function_call_stream():
                )  # extend conversation with function response
            print(f"messages: {messages}")
            second_response = litellm.completion(
                model="gpt-3.5-turbo-1106", messages=messages, temperature=0.2, seed=22
            )  # get a new response from the model where it can see the function response
            print("second response\n", second_response)
            return second_response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


test_parallel_function_call_stream()
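The hunks above skip over the per-tool-call dispatch step inside the loop. Given the `get_current_weather` stub and the `tool` message format used here, that step usually has the shape sketched below; the helper name and variable names are an illustrative reconstruction, not the exact elided lines:

import json


def dispatch_tool_calls(tool_calls, available_functions, messages):
    """Parse each tool call's JSON arguments, run the matching local function,
    and append the result as a `tool` message (illustrative helper, not the elided code)."""
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        function_to_call = available_functions[function_name]
        function_args = json.loads(tool_call.function.arguments)
        function_response = function_to_call(**function_args)
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": function_response,
            }
        )  # extend conversation with function response
    return messages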
@@ -11,9 +11,11 @@ sys.path.insert(
import pytest
import litellm


def test_get_llm_provider():
    _, response, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
    assert response == "bedrock"


test_get_llm_provider()
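For readers skimming the diff, `litellm.get_llm_provider` returns a 4-tuple and the test keeps only the second element, the provider string. A minimal sketch with all four positions named; the local variable names are illustrative labels, not part of litellm's API:

import litellm

# Unpack the same call the test makes, but name every position of the tuple.
model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="anthropic.claude-v2:1"
)
print(provider)  # expected: "bedrock"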
@@ -9,41 +9,60 @@ import litellm
from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
import pytest


def test_get_gpt3_tokens():
    max_tokens = get_max_tokens("gpt-3.5-turbo")
    print(max_tokens)
    assert max_tokens == 4097
    # print(results)


test_get_gpt3_tokens()


def test_get_palm_tokens():
    # # 🦄🦄🦄🦄🦄🦄🦄🦄
    max_tokens = get_max_tokens("palm/chat-bison")
    assert max_tokens == 4096
    print(max_tokens)


test_get_palm_tokens()


def test_zephyr_hf_tokens():
    max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
    print(max_tokens)
    assert max_tokens == 32768


test_zephyr_hf_tokens()


def test_cost_ft_gpt_35():
    try:
        # this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id
        # it needs to lookup ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost
        from litellm import ModelResponse, Choices, Message
        from litellm.utils import Usage

        resp = ModelResponse(
            id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
            choices=[
                Choices(
                    finish_reason=None,
                    index=0,
                    message=Message(
                        content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                        role="assistant",
                    ),
                )
            ],
            created=1700775391,
            model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
            object="chat.completion",
            system_fingerprint=None,
            usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
        )

        cost = litellm.completion_cost(completion_response=resp)
@@ -51,35 +70,58 @@ def test_cost_ft_gpt_35():
        input_cost = model_cost["ft:gpt-3.5-turbo"]["input_cost_per_token"]
        output_cost = model_cost["ft:gpt-3.5-turbo"]["output_cost_per_token"]
        print(input_cost, output_cost)
        expected_cost = (input_cost * resp.usage.prompt_tokens) + (
            output_cost * resp.usage.completion_tokens
        )
        print("\n Excpected cost", expected_cost)
        assert cost == expected_cost
    except Exception as e:
        pytest.fail(
            f"Cost Calc failed for ft:gpt-3.5. Expected {expected_cost}, Calculated cost {cost}"
        )


test_cost_ft_gpt_35()


def test_cost_azure_gpt_35():
    try:
        # this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo
        # for this test we check if passing `model` to completion_cost overrides the completion cost
        from litellm import ModelResponse, Choices, Message
        from litellm.utils import Usage

        resp = ModelResponse(
            id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
            choices=[
                Choices(
                    finish_reason=None,
                    index=0,
                    message=Message(
                        content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                        role="assistant",
                    ),
                )
            ],
            model="azure/gpt-35-turbo",  # azure always has model written like this
            usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
        )

        cost = litellm.completion_cost(
            completion_response=resp, model="azure/gpt-3.5-turbo"
        )
        print("\n Calculated Cost for azure/gpt-3.5-turbo", cost)
        input_cost = model_cost["azure/gpt-3.5-turbo"]["input_cost_per_token"]
        output_cost = model_cost["azure/gpt-3.5-turbo"]["output_cost_per_token"]
        expected_cost = (input_cost * resp.usage.prompt_tokens) + (
            output_cost * resp.usage.completion_tokens
        )
        print("\n Excpected cost", expected_cost)
        assert cost == expected_cost
    except Exception as e:
        pytest.fail(
            f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
        )


test_cost_azure_gpt_35()
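Both cost tests compare against the same linear formula: per-token input price times prompt tokens plus per-token output price times completion tokens. A quick worked example with the usage counts from the tests above and placeholder prices, just to make the arithmetic concrete:

# Worked example of the expected-cost formula used in the tests above.
# The per-token prices here are illustrative placeholders, not litellm's model_cost values.
prompt_tokens = 21
completion_tokens = 17
input_cost_per_token = 0.0000015  # hypothetical price
output_cost_per_token = 0.000002  # hypothetical price

expected_cost = (input_cost_per_token * prompt_tokens) + (
    output_cost_per_token * completion_tokens
)
print(expected_cost)  # 21 * 1.5e-6 + 17 * 2e-6 = 6.55e-05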
@@ -11,19 +11,38 @@ sys.path.insert(
import pytest
from litellm.llms.prompt_templates.factory import prompt_factory


def test_prompt_formatting():
    try:
        prompt = prompt_factory(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            messages=[
                {"role": "system", "content": "Be a good bot"},
                {"role": "user", "content": "Hello world"},
            ],
        )
        assert (
            prompt == "<s>[INST] Be a good bot [/INST]</s> [INST] Hello world [/INST]"
        )
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


def test_prompt_formatting_custom_model():
    try:
        prompt = prompt_factory(
            model="ehartford/dolphin-2.5-mixtral-8x7b",
            messages=[
                {"role": "system", "content": "Be a good bot"},
                {"role": "user", "content": "Hello world"},
            ],
            custom_llm_provider="huggingface",
        )
        print(f"prompt: {prompt}")
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# test_prompt_formatting_custom_model()
# def logger_fn(user_model_dict):
#     return
@@ -5,45 +5,71 @@ import sys, os
import traceback
from dotenv import load_dotenv
import logging

logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import os
import asyncio

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm


def test_image_generation_openai():
    litellm.set_verbose = True
    response = litellm.image_generation(
        prompt="A cute baby sea otter", model="dall-e-3"
    )
    print(f"response: {response}")
    assert len(response.data) > 0


# test_image_generation_openai()


def test_image_generation_azure():
    response = litellm.image_generation(
        prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview"
    )
    print(f"response: {response}")
    assert len(response.data) > 0


# test_image_generation_azure()


def test_image_generation_azure_dall_e_3():
    litellm.set_verbose = True
    response = litellm.image_generation(
        prompt="A cute baby sea otter",
        model="azure/dall-e-3-test",
        api_version="2023-12-01-preview",
        api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
        api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
    )
    print(f"response: {response}")
    assert len(response.data) > 0


# test_image_generation_azure_dall_e_3()


@pytest.mark.asyncio
async def test_async_image_generation_openai():
    response = litellm.image_generation(
        prompt="A cute baby sea otter", model="dall-e-3"
    )
    print(f"response: {response}")
    assert len(response.data) > 0


# asyncio.run(test_async_image_generation_openai())


@pytest.mark.asyncio
async def test_async_image_generation_azure():
    response = await litellm.aimage_generation(
        prompt="A cute baby sea otter", model="azure/dall-e-3-test"
    )
    print(f"response: {response}")
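The async tests above run under pytest-asyncio; outside pytest the same coroutine can be driven with `asyncio.run`. A minimal sketch, assuming an OpenAI key is configured; the wrapper function name is illustrative:

import asyncio

import litellm


async def generate_otter_image():
    # Mirrors the async test above: request one DALL-E 3 image via litellm's async API.
    response = await litellm.aimage_generation(
        prompt="A cute baby sea otter", model="dall-e-3"
    )
    print(f"response: {response}")
    return response


# asyncio.run(generate_otter_image())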
@@ -104,4 +104,3 @@
# # pytest.fail(f"Error occurred: {e}")
# # test_openai_with_params()